Refine docs and readme.md, refactor nn_generator (#70)

This commit is contained in:
Jiahang Xu 2022-06-14 15:52:32 +08:00 committed by GitHub
Parent d5d4cd92a2
Commit 25a53d89f4
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 288 additions and 219 deletions

View File

@ -17,7 +17,7 @@ The currently supported hardware and inference frameworks:
- Those who want to get the DNN inference latency on mobile and edge devices with **no deployment efforts on real devices**.
- Those who want to run **hardware-aware NAS with [NNI](https://github.com/microsoft/nni)**.
- Those who want to **build latency predictors for their own devices**.
- Those who want to **build latency predictors for their own devices** ([Documents](https://github.com/microsoft/nn-Meter/blob/main/docs/builder/overview.md) of nn-Meter builder).
- Those who want to use the 26k latency [benchmark dataset](https://github.com/microsoft/nn-Meter/releases/download/v1.0-data/datasets.zip).
# Installation
@ -176,6 +176,17 @@ Users could view the information of all built-in predictors by `list_latency_predic
Users could get a nn-Meter IR graph by applying `model_file_to_graph` or `model_to_graph`, passing the model file name or model object and specifying the model type. The supported model types of `model_file_to_graph` include "onnx", "pb", "torch", "nnmeter-ir" and "nni-ir", while the supported model types of `model_to_graph` include "onnx", "torch" and "nni-ir".
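For illustration, here is a minimal sketch of the conversion, assuming `model_file_to_graph` is importable from the top-level `nn_meter` package (check the package for the exact signature):
```python
from nn_meter import model_file_to_graph  # assumption: exported at the package top level

# convert an ONNX model file to a nn-Meter IR graph; the path is illustrative
ir_graph = model_file_to_graph("path/to/model.onnx", model_type="onnx")
```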
## nn-Meter Builder
nn-Meter builder is an open source tool for users to build latency predictors on their own devices. There are three main parts in nn-Meter builder:
**backend**: the module for connecting backends;
**backend_meta**: the meta tools related to the backend. Here we provide a fusion rule tester to detect fusion rules for users' backends;
**kernel_predictor_builder**: the tool to build latency predictors for different kernels.
Users could access nn-Meter builder through `nn_meter.builder`. For more details on using nn-Meter builder, please check the document of [nn-Meter builder](https://github.com/microsoft/nn-Meter/blob/main/docs/builder/overview.md).
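As a quick orientation, here is a short sketch of the builder entry points used later in this document (both calls appear below; the paths are placeholders):
```python
from nn_meter.builder import builder_config, build_latency_predictor

# point the builder at an existing workspace, then run the end-to-end pipeline
builder_config.init(workspace_path="path/to/workspace/folder")
build_latency_predictor(backend="tflite_cpu")
```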
## Hardware-aware NAS by nn-Meter and NNI

View File

@ -1,12 +1,40 @@
# Build Kernel Latency Predictor
## Step1: Config Sampling From Prior Distribution
## Step1: Prepare Backends and Create Workspace
The first step in building a kernel latency predictor is to prepare backends and create a workspace. Users could follow the guidance in [Prepare Backends](./prepare_backend.md) and [Create Workspace](./overview.md#create-workspace) for this step.
After creating the workspace, a yaml file named `predictorbuild_config.yaml` will be placed in `<workspace-path>/configs/`. The predictor build configs include the following (a sketch of the config file follows this list):
- `DETAIL`: Whether to attach detailed information to the json output, such as the shape and configuration information in profiled results. Default value is `FALSE`.
- `IMPLEMENT`: The code implementation, chosen from [`tensorflow`, `torch`].
- `BATCH_SIZE`: The batch size in kernel profiling. Default value is 1.
- `KERNELS`: The training parameters for each kernel. By default, nn-Meter sets 16 kernels, including "conv-bn-relu", "dwconv-bn-relu", "maxpool", "avgpool", "fc", "concat", "split", "channelshuffle", "se", "global-avgpool", "bnrelu", "bn", "hswish", "relu", "addrelu", "add". For each type of kernel, the parameters include:
- `INIT_SAMPLE_NUM`: the data size for predictor initialization.
- `FINEGRAINED_SAMPLE_NUM`: the data size for adaptive sampling. For each data point with error higher than `ERROR_THRESHOLD`, `FINEGRAINED_SAMPLE_NUM` new samples will be generated based on the large-error data. Defaults to 20.
- `ITERATION`: the number of iterations for sampling and training. Predictor training based on initial sampling is regarded as iteration 1, thus `iteration == 2` means one iteration of adaptive sampling.
- `ERROR_THRESHOLD`: the threshold of large error. Defaults to 0.1.
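For reference, here is a minimal sketch of how `<workspace-path>/configs/predictorbuild_config.yaml` might be laid out, filled with the default values listed above (one kernel shown; the exact nesting in the shipped file may differ):
```yaml
DETAIL: FALSE
IMPLEMENT: tensorflow
BATCH_SIZE: 1
KERNELS:
  conv-bn-relu:
    INIT_SAMPLE_NUM: 1000   # assumption: taken from the API default documented below
    FINEGRAINED_SAMPLE_NUM: 20
    ITERATION: 5
    ERROR_THRESHOLD: 0.1
```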
Users could open `<workspace-path>/configs/predictorbuild_config.yaml` and edit the content. After completing the configuration, users could initialize the workspace in the `builder_config` module before building the kernel latency predictor:
```python
from nn_meter.builder import builder_config
# initialize builder config with workspace
builder_config.init(
    workspace_path="path/to/workspace/folder"
) # change the text to the required workspace path
```
Note: after running ``builder_config.init``, the config is already loaded. If users want to update the config, they should save and close the updated config file, then run ``builder_config.init`` again to reload the config space; only then will the changes take effect.
## Step2: Config Sampling From Prior Distribution
To learn the relationship between configurations and latency, we need to generate a training set (i.e., variously configured kernels and their latencies) for regression. While it is infeasible to sample and measure all the configurations for all kernels, a direct method is random sampling.
The first step is sampling configuration values from the prior distribution, which is inferred from existing models. Based on our kernel model, there are generally 6 configuration values, including height and width (`"HW"`), input channel (`"CIN"`), output channel (`"COUT"`), kernel size (`"KERNEL_SIZE"`), strides (`"STRIDES"`), and kernel size for the pooling layer (`"POOL_STRIDES"`). We sample the configurations from the prior distribution and snap each value to a common valid value. That is, height and width are restricted to values from `[1, 3, 7, 14, 28, 56, 112, 224]`, kernel size to `[1, 3, 5, 7]`, strides to `[1, 2, 4]`, and kernel size for the pooling layer to `[2, 3]`. We store the prior knowledge of existing models as csv files in `nn_meter/builder/kernel_predictor_builder/data_sampler/prior_config_lib/`.
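A toy sketch of the snapping step described above, using the valid value sets from this paragraph (the helper name is hypothetical, not part of nn-Meter's API):
```python
VALID_HW = [1, 3, 7, 14, 28, 56, 112, 224]
VALID_KERNEL_SIZE = [1, 3, 5, 7]
VALID_STRIDES = [1, 2, 4]
VALID_POOL_STRIDES = [2, 3]

def snap_to_valid(value, valid_values):
    """Map a sampled configuration value to the closest common valid value."""
    return min(valid_values, key=lambda v: abs(v - value))

print(snap_to_valid(100, VALID_HW))  # -> 112
```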
## Step 2: Generate and Profile Kernel Model by Configs
## Step 3: Generate and Profile Kernel Model by Configs
The next step is generating and profiling kernel models from the sampled configurations. nn-Meter supports both Tensorflow and PyTorch implementations of kernels. Users could switch the kernel implementation between Tensorflow and PyTorch by editing the configuration `IMPLEMENT` in `<workspace-path>/configs/predictorbuild_config.yaml`. Here we use the Tensorflow implementation and the `"tflite_cpu"` backend as an example.
@ -65,7 +93,9 @@ kernel_data = sample_and_profile_kernel_data(kernel_type, sample_num=sample_num,
backend=backend, sampling_mode='prior', mark=mark)
```
The generated models are saved in `<workspace-path>/predictor_build/models`, and the configuration information and profiled results are dumped as json files to `<workspace-path>/predictor_build/results/<kernel_type>.json` and `<workspace-path>/predictor_build/results/profiled_<kernel_type>.json` respectively.
The generated models are saved in `<workspace-path>/predictor_build/kernels`, and the configuration information and profiled results are dumped as json files to `<workspace-path>/predictor_build/results/<kernel_type>.json` and `<workspace-path>/predictor_build/results/profiled_<kernel_type>.json` respectively.
Note: sometimes the number of sampled kernel data points is smaller than the value of `sample_num`. This is natural, since nn-Meter removes duplicate samples when generating kernel data.
@ -98,7 +128,7 @@ profiled_results = profile_models(backend, models, mode='predbuild', have_conver
Note: for kernels related to `conv` or `dwconv`, our experiment results have shown that all kernels containing one `conv` layer have almost the same latency results, as the `conv` layer has the dominant latency. For example, `conv-bn-relu` has almost the same latency as `conv-block`. The same observation was found for `dwconv`-related kernels. Therefore in nn-Meter, all `conv`-related kernels share the same kernel predictor, as do all `dwconv`-related kernels.
## Step 3: Initialize Kernel Latency Predictor
## Step 4: Initialize Kernel Latency Predictor
After preparing the training data, we construct a random forest regression model as the kernel latency predictor. Here is an example:
@ -154,7 +184,7 @@ predictor, data = build_initial_predictor_by_data(
)
```
## Step 4: Adaptive Data Sampling
## Step 5: Adaptive Data Sampling
In the paper of nn-Meter, we observed that the configurations of kernel size (`KERNEL_SIZE`), height and width (`HW`), input channel (`CIN`), and output channel (`COUT`) show non-linear patterns on our measured devices. In particular, `HW` and `COUT` exhibit a staircase pattern, in which Conv layers with two different `HW` / `COUT` values may have the same latency. These non-linearities reflect the complexities of hardware optimizations.
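To make the adaptive step concrete, here is a hypothetical sketch of how large-error data could be selected (mirroring the `ERROR_THRESHOLD` description above; this is not nn-Meter's internal code):
```python
def select_large_error_configs(configs, y_true, y_pred, error_threshold=0.1):
    """Return the configs whose relative prediction error exceeds the threshold.
    Finegrained samples are then generated around each selected config."""
    return [cfg for cfg, t, p in zip(configs, y_true, y_pred)
            if abs(p - t) / t > error_threshold]
```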
@ -185,10 +215,10 @@ backend = "tflite_cpu"
error_threshold = 0.1
predictor, data = build_adaptive_predictor_by_data(
kernel_type, data, backend, finegrained_sample_num=5
kernel_type, kernel_data, backend, finegrained_sample_num=5
)
```
In the method `build_adaptive_predictor_by_data`, the parameter `data` indicates all training and testing data for the current predictor training. The value of `data` could either be an instance of Dict generated by `build_initial_predictor_by_data` or `build_adaptive_predictor_by_data`, or be an instance of Tuple such as:
In the method `build_adaptive_predictor_by_data`, the parameter `kernel_data` indicates all training and testing data for the current predictor training. The value of `kernel_data` could either be an instance of Dict generated by `build_initial_predictor_by_data` or `build_adaptive_predictor_by_data`, or be an instance of Tuple such as:
```python
config_json_file = [
@ -237,7 +267,108 @@ from nn_meter.builder import build_latency_predictor
build_latency_predictor(backend="tflite_cpu")
```
# Build predictor for customized kernel
# Kernel Data Format
## Structure of Kernel Data
In the process of building a kernel latency predictor, a series of kernel data is sampled, generated, and profiled to build the training dataset. One piece of complete kernel data consists of two parts: the configuration information and the profiled results (by default, the latency value). The configuration information and profiled results are dumped as json files to `<workspace-path>/predictor_build/results/<kernel_type>_<mark>.json` and `<workspace-path>/predictor_build/results/profiled_<kernel_type>.json` respectively. In each data piece, `"model"` points to the path of this kernel's `Keras` model, `"shapes"` indicates the input shape of the tensor to test, and `"latency"` reports the profiled results after running `profile_models`. The ids of kernel data are randomly generated and consist of 6 capital letters and digits.
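As an aside, such random ids could be produced by something as simple as the following (a guess at the scheme, not the actual generator):
```python
import random
import string

def random_kernel_id(length=6):
    """Generate a random id of capital letters and digits, e.g. 'YB2F4N'."""
    return "".join(random.choices(string.ascii_uppercase + string.digits, k=length))
```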
This is a json dump of the configuration information of generated kernels, which we call the config json file:
```json
"conv-bn-relu": {
"YB2F4N": {
"model": "<workspace-path>/predictor_build/kernels/conv-bn-relu_prior_YB2F4N",
"shapes": [
[
13,
13,
212
]
],
"config": {
"HW": 13,
"CIN": 212,
"COUT": 176,
"KERNEL_SIZE": 1,
"STRIDES": 1
}
}
}
```
After running and profiling the kernels, the `"latency"` attribute appears in `<workspace-path>/predictor_build/results/profiled_<kernel_type>.json`, which we call the latency json file:
```json
"conv-bn-relu": {
"YB2F4N": {
"latency": "37.53 +- 0.314"
}
}
```
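The latency value is stored as a string of the form `"<mean> +- <std>"`. A small sketch of parsing it back into numbers, assuming exactly this format:
```python
def parse_latency(latency_str):
    """Split a profiled latency string such as '37.53 +- 0.314' into (mean, std)."""
    mean, std = latency_str.split(" +- ")
    return float(mean), float(std)

print(parse_latency("37.53 +- 0.314"))  # (37.53, 0.314)
```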
Note: If the parameter `DETAIL` is `TRUE` in `<workspace-path>/configs/predictorbuild_config.yaml`, the configuration information will also be dumped in `<workspace-path>/predictor_build/results/profiled_<kernel_type>.json`.
The config json file and latency json file together form the training data of the kernel latency predictor. The kernel data should be defined as follows:
```python
workspace = "path/to/workspace/folder"  # assumption: the workspace path set in builder_config.init
config_json_file = [
f'{workspace}/predictor_build/results/{kernel_type}_prior.json',
f'{workspace}/predictor_build/results/{kernel_type}_finegrained1.json',
f'{workspace}/predictor_build/results/{kernel_type}_finegrained2.json'
]
latency_json_file = [
f'{workspace}/predictor_build/results/profiled_{kernel_type}.json'
]
kernel_data = (config_json_file, latency_json_file)
```
and used as follows:
```python
import os

# build initial latency predictor by kernel_data
predictor, acc10, error_configs = build_predictor_by_data(
kernel_type, kernel_data, backend, error_threshold=error_threshold, mark="prior",
save_path=os.path.join(workspace, "predictor_build", "results")
)
```
or:
```python
# build adaptive latency predictor by kernel_data
predictor, data = build_adaptive_predictor_by_data(
kernel_type, kernel_data, backend, finegrained_sample_num=5
)
```
## Convert Kernel Data to CSV
nn-Meter provides a method to convert kernel data json files to CSV. Here is an example:
```python
from nn_meter.builder.kernel_predictor_builder.predictor_builder.utils import collect_kernel_data
from nn_meter.builder.kernel_predictor_builder.predictor_builder.extract_feature import get_feature_parser, get_data_by_profiled_results
workspace = "path/to/workspace/folder"  # assumption: the workspace path set in builder_config.init
kernel_type = 'conv-bn-relu'
# define kernel data
config_json_file = [
f'{workspace}/predictor_build/results/{kernel_type}_prior.json',
f'{workspace}/predictor_build/results/{kernel_type}_finegrained1.json',
f'{workspace}/predictor_build/results/{kernel_type}_finegrained2.json'
]
latency_json_file = [
f'{workspace}/predictor_build/results/profiled_{kernel_type}.json'
]
kernel_data = (config_json_file, latency_json_file)
# read kernel data and extract features
kernel_data = collect_kernel_data(kernel_data)
feature_parser = get_feature_parser(kernel_type) # define the feature to extract
data = get_data_by_profiled_results(kernel_type, feature_parser, kernel_data,
                                    save_path="path/to/csv/test.csv")
```
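The saved file is a plain CSV, so it can be inspected with standard tooling, e.g. with pandas (assuming it is installed):
```python
import pandas as pd

df = pd.read_csv("path/to/csv/test.csv")  # the save_path used above
print(df.head())
```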
# Build Predictor for Customized Kernel
If users want to add new kernels to profile their latency and build predictors, here are several steps to prepare and register the new kernels.
@ -245,7 +376,7 @@ If users want to add new kernels to profile latency and build predictor, here ar
### Step 1: Prepare the Customized Kernel Class
nn-Meter provides an API for users to customize their own kernel block. In nn-Meter, each kernel is implemented by inheriting a base class named `nn_meter.builder.nn_generator.BaseBlock`. The kernel block has an input parameter `config` to feed configuration params into the kernel. Two attributes should be declared, namely `input_shape` and `input_tensor_shape`, as well as one method named `get_model()`. nn-Meter supports both Tensorflow and PyTorch implementations of the kernel model. Users could switch the kernel implementation between Tensorflow and PyTorch by editing the configuration `IMPLEMENT` in `<workspace-path>/configs/predictorbuild_config.yaml`. Here we use the Tensorflow implementation as an example.
nn-Meter provides an API for users to customize their own kernel block. In nn-Meter, each kernel is implemented by inheriting a base class named `nn_meter.builder.nn_modules.BaseBlock`. The kernel block has an input parameter `config` to feed configuration params into the kernel. Two attributes should be declared, namely `input_shape` and `input_tensor_shape`, as well as one method named `get_model()`. nn-Meter supports both Tensorflow and PyTorch implementations of the kernel model. Users could switch the kernel implementation between Tensorflow and PyTorch by editing the configuration `IMPLEMENT` in `<workspace-path>/configs/predictorbuild_config.yaml`. Here we use the Tensorflow implementation as an example.
- `input_shape` defines the dimensions of one model input shape without the batch size. Generally, when the input shape is 3D, `input_shape` should be `[config["HW"], config["HW"], config["CIN"]]`, and when the input shape is 1D, `input_shape` should be `[config["CIN"]]`.
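For instance, with a hypothetical config the two cases look like this:
```python
config = {"HW": 28, "CIN": 16}  # illustrative values only

input_shape_3d = [config["HW"], config["HW"], config["CIN"]]  # [28, 28, 16]
input_shape_1d = [config["CIN"]]                               # [16]
```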
@ -257,7 +388,7 @@ Users could refer to the following example to learn how to write a kernel class.
``` python
import tensorflow.keras as keras
from nn_meter.builder.nn_generator import BaseBlock
from nn_meter.builder.nn_modules import BaseBlock
class MyKernel(BaseBlock):
''' This kernel is built by Conv, BN, and Relu layer, which is the same as the builtin `conv-bn-relu` block.

View File

@ -285,7 +285,7 @@ If users want to add new operators into basic test cases, here are several steps
### Step 1: Prepare the Customized Operator Class
nn-Meter provides an API for users to customize their own operator. In nn-Meter, each operator is implemented by inheriting a base class named `nn_meter.builder.nn_generator.BaseOperator`. The class has two input parameters, i.e., `input_shape` and `config`. `input_shape` is a list showing the dimensions of the input tensor (the batch dimension should not be included), and `config` can be used to feed configuration params into the operator. The base class has the following methods:
nn-Meter provides an API for users to customize their own operator. In nn-Meter, each operator is implemented by inheriting a base class named `nn_meter.builder.nn_modules.BaseOperator`. The class has two input parameters, i.e., `input_shape` and `config`. `input_shape` is a list showing the dimensions of the input tensor (the batch dimension should not be included), and `config` can be used to feed configuration params into the operator. The base class has the following methods:
- `get_model`: Return the model function of the operator. Users always need to override this method.
@ -301,7 +301,7 @@ The first example is Conv2d operator. The operator simply applies APIs from `te
``` python
import tensorflow.keras as keras
from nn_meter.builder.nn_generator import BaseOperator
from nn_meter.builder.nn_modules import BaseOperator
class Conv(BaseOperator):
def get_model(self):
@ -363,7 +363,7 @@ nn-Meter requires users to gather all code of operator in a package with a fixed
The interface of the customized operator class is stored in `./customized_operator/operator_script.py`. In this demo, the content of `operator_script.py` includes:
``` python
from nn_meter.builder.nn_generator import BaseOperator
from nn_meter.builder.nn_modules import BaseOperator
from tensorflow import keras
class Op1(BaseOperator):

View File

@ -85,7 +85,7 @@ class BaseTestCase:
self.kernel_size = config['KERNEL_SIZE']
self.cout = config['COUT']
self.padding = config['PADDING']
self.workspace_path = os.path.join(config['WORKSPACE'], 'models')
self.workspace_path = os.path.join(config['WORKSPACE'], 'testcases')
os.makedirs(self.workspace_path, exist_ok=True)
def _model_block(self):

View File

@ -52,9 +52,9 @@ def get_operator_by_name(operator_name, input_shape, config = None, implement =
elif operator_name in __BUILTIN_OPERATORS__:
operator_module_name = __BUILTIN_OPERATORS__[operator_name]
if implement == 'tensorflow':
from nn_meter.builder.nn_generator.tf_networks import operators
from nn_meter.builder.nn_modules.tf_networks import operators
elif implement == 'torch':
from nn_meter.builder.nn_generator.torch_networks import operators
from nn_meter.builder.nn_modules.torch_networks import operators
else:
raise NotImplementedError('You must choose one implementation of kernel from "tensorflow" or "pytorch"')
operator_module = operators
@ -96,10 +96,10 @@ def get_special_testcases_by_name(testcase, implement=None):
def generate_models_for_testcase(op1, op2, input_shape, config, implement):
if implement == 'tensorflow':
from .build_tf_models import SingleOpModel, TwoOpModel
from nn_meter.builder.nn_generator.tf_networks.utils import get_inputs_by_shapes
from nn_meter.builder.nn_modules.tf_networks.utils import get_inputs_by_shapes
elif implement == 'torch':
from .build_torch_models import SingleOpModel, TwoOpModel
from nn_meter.builder.nn_generator.torch_networks.utils import get_inputs_by_shapes
from nn_meter.builder.nn_modules.torch_networks.utils import get_inputs_by_shapes
else:
raise NotImplementedError('You must choose one implementation of kernel from "tensorflow" or "pytorch"')
@ -124,10 +124,10 @@ def generate_models_for_testcase(op1, op2, input_shape, config, implement):
def generate_single_model(op, input_shape, config, implement):
if implement == 'tensorflow':
from .build_tf_models import SingleOpModel
from nn_meter.builder.nn_generator.tf_networks.utils import get_inputs_by_shapes
from nn_meter.builder.nn_modules.tf_networks.utils import get_inputs_by_shapes
elif implement == 'torch':
from .build_torch_models import SingleOpModel
from nn_meter.builder.nn_generator.torch_networks.utils import get_inputs_by_shapes
from nn_meter.builder.nn_modules.torch_networks.utils import get_inputs_by_shapes
else:
raise NotImplementedError('You must choose one implementation of kernel from "tensorflow" or "pytorch"')
@ -143,13 +143,14 @@ def generate_single_model(op, input_shape, config, implement):
def save_model(model, model_path, implement):
if implement == 'tensorflow':
from tensorflow import keras
from nn_meter.builder.nn_generator.tf_networks.utils import get_tensor_by_shapes
from nn_meter.builder.nn_modules.tf_networks.utils import get_tensor_by_shapes
model['model'](get_tensor_by_shapes(model['shapes']))
keras.models.save_model(model['model'], model_path)
return model_path
elif implement == 'torch':
import torch
from nn_meter.builder.nn_generator.torch_networks.utils import get_inputs_by_shapes
from nn_meter.builder.nn_modules.torch_networks.utils import get_inputs_by_shapes
torch.onnx.export(
model['model'],
get_inputs_by_shapes(model['shapes']),

View File

@ -16,7 +16,7 @@ class KernelGenerator:
self.kernel_type = kernel_type
self.sample_num = sample_num
self.workspace_path = builder_config.get('WORKSPACE', 'predbuild')
self.case_save_path = os.path.join(self.workspace_path, 'models')
self.case_save_path = os.path.join(self.workspace_path, 'kernels')
self.kernel_info = {kernel_type: {}}
self.kernels = self.kernel_info[self.kernel_type]
self.implement = builder_config.get('IMPLEMENT', 'predbuild')
@ -37,6 +37,8 @@ class KernelGenerator:
"""
kernel_type = self.kernel_type
logging.info(f"building kernel for {kernel_type}...")
count = 0
error_save_path = os.path.join(self.workspace_path, 'results', 'generate_error.log')
for id, value in self.kernels.items():
model_path = os.path.join(self.case_save_path, ("_".join([kernel_type, self.mark, id]) + self.model_suffix))
kernel_cfg = value['config']
@ -50,9 +52,19 @@ class KernelGenerator:
'shapes': input_tensor_shape,
'config': config
}
except:
pass
count += 1
except Exception as e:
open(os.path.join(self.workspace_path, "results", "generate_error.log"), 'a').write(f"{id}: {e}\n")
# save information to json file in incremental mode
info_save_path = os.path.join(self.workspace_path, "results", f"{kernel_type}_{self.mark}.json")
new_kernels_info = merge_info(new_info=self.kernel_info, info_save_path=info_save_path)
os.makedirs(os.path.dirname(info_save_path), exist_ok=True)
with open(info_save_path, 'w') as fp:
json.dump(new_kernels_info, fp, indent=4)
logging.keyinfo(f"Generate {len(self.kernels)} kernels and save info to {info_save_path} " \
f"Failed information are saved in {error_save_path} (if any).")
def run(self, sampling_mode = 'prior', configs = None):
""" sample N configurations for target kernel, generate tensorflow keras model files.
@ -89,13 +101,4 @@ def generate_config_sample(kernel_type, sample_num, mark = '', sampling_mode = '
generator = KernelGenerator(kernel_type, sample_num, mark=mark)
kernels_info = generator.run(sampling_mode=sampling_mode, configs=configs)
# save information to json file in incremental mode
workspace_path = builder_config.get('WORKSPACE', "predbuild")
info_save_path = os.path.join(workspace_path, "results", f"{kernel_type}_{mark}.json")
new_kernels_info = merge_info(new_info=kernels_info, info_save_path=info_save_path)
os.makedirs(os.path.dirname(info_save_path), exist_ok=True)
with open(info_save_path, 'w') as fp:
json.dump(new_kernels_info, fp, indent=4)
logging.keyinfo(f"Save the kernel model information to {info_save_path}")
return kernels_info

View File

@ -69,9 +69,9 @@ def generate_model_for_kernel(kernel_type, config, save_path, implement='tensorf
elif kernel_type in __BUILTIN_KERNELS__:
kernel_name = __BUILTIN_KERNELS__[kernel_type][0]
if implement == 'tensorflow':
from nn_meter.builder.nn_generator.tf_networks import blocks
from nn_meter.builder.nn_modules.tf_networks import blocks
elif implement == 'torch':
from nn_meter.builder.nn_generator.torch_networks import blocks
from nn_meter.builder.nn_modules.torch_networks import blocks
else:
raise NotImplementedError('You must choose one implementation of kernel from "tensorflow" or "pytorch"')
kernel_module = blocks

View File

@ -143,7 +143,7 @@ def get_data_by_profiled_results(kernel_type, feature_parser, cfgs_path, labs_pa
}
if labs_path == None, it means latency (or other label) information is also included in cfgs_path.
save_path (str): the path to save the feature and latency information. If save_path is None, the data will not be saved.
save_path (str): the path to save the feature and latency information, including file name. If save_path is None, the data will not be saved.
predict_label (str): the predicting label to build kernel predictor
'''

View File

@ -32,13 +32,15 @@ def convert_models(backend, models, mode = 'predbuild', broken_point_mode = Fals
save_name = "converted_results.json"
workspace_path = builder_config.get('WORKSPACE', mode)
model_save_path = os.path.join(workspace_path, 'models')
model_save_path = os.path.join(workspace_path, 'testcases' if mode == 'ruletest' else 'kernels')
os.makedirs(model_save_path, exist_ok=True)
info_save_path = os.path.join(workspace_path, "results")
os.makedirs(info_save_path, exist_ok=True)
res_save_path = os.path.join(workspace_path, "results")
os.makedirs(res_save_path, exist_ok=True)
# convert models
count = 0
info_save_path = os.path.join(res_save_path, save_name)
error_save_path = os.path.join(res_save_path, "convert_error.log")
for module in models.values():
for id, model in module.items():
if broken_point_mode and 'converted_model' in model:
@ -47,25 +49,22 @@ def convert_models(backend, models, mode = 'predbuild', broken_point_mode = Fals
model_path = model['model']
converted_model = backend.convert_model(model_path, model_save_path, model['shapes'])
model['converted_model'] = converted_model
count += 1
except Exception as e:
open(os.path.join(info_save_path, "convert_error.log"), 'a').write(f"{id}: {e}\n")
open(error_save_path, 'a').write(f"{id}: {e}\n")
# save information to json file per 50 models
count += 1
if count % 50 == 0:
with open(os.path.join(info_save_path, save_name), 'w') as fp:
with open(info_save_path, 'w') as fp:
json.dump(models, fp, indent=4)
logging.keyinfo(f"{count} models complete. Still converting... Save the intermediate results to {os.path.join(info_save_path, save_name)}.")
with open(os.path.join(info_save_path, save_name), 'w') as fp:
json.dump(models, fp, indent=4)
logging.keyinfo(f"Complete convert all {count} models. Save the intermediate results to {os.path.join(info_save_path, save_name)}.")
logging.keyinfo(f"{count} models complete. Still converting... Save the intermediate results to {info_save_path} ")
# save information to json file
with open(os.path.join(info_save_path, save_name), 'w') as fp:
with open(info_save_path, 'w') as fp:
json.dump(models, fp, indent=4)
logging.keyinfo(f"Save the converted models information to {os.path.join(info_save_path, save_name)}")
logging.keyinfo(f"Complete converting all {count} models. Save the results to {info_save_path} " \
f"Failed information are saved in {error_save_path} (if any)")
return models
@ -99,16 +98,17 @@ def profile_models(backend, models, mode = 'ruletest', metrics = ["latency"], sa
models = json.load(fp)
workspace_path = builder_config.get('WORKSPACE', mode)
model_save_path = os.path.join(workspace_path, 'models')
model_save_path = os.path.join(workspace_path, 'testcases' if mode == 'ruletest' else 'kernels')
os.makedirs(model_save_path, exist_ok=True)
info_save_path = os.path.join(workspace_path, "results")
os.makedirs(info_save_path, exist_ok=True)
res_save_path = os.path.join(workspace_path, "results")
os.makedirs(res_save_path, exist_ok=True)
info_save_path = os.path.join(res_save_path, save_name)
# in broken point mode, if the output file `<workspace>/<mode-folder>/results/<save-name>` exists,
# load the existing latency and skip these models in profiling
if broken_point_mode and os.path.isfile(os.path.join(info_save_path, save_name)):
if broken_point_mode and os.path.isfile(info_save_path):
from nn_meter.builder.backend_meta.utils import read_profiled_results
with open(os.path.join(info_save_path, save_name), 'r') as fp:
with open(info_save_path, 'r') as fp:
profiled_models = read_profiled_results(json.load(fp))
for module_key, module in models.items():
if module_key not in profiled_models:
@ -118,7 +118,8 @@ def profile_models(backend, models, mode = 'ruletest', metrics = ["latency"], sa
model.update(profiled_models[module_key][id])
# profile models and get metric results
count = 0
error_save_path = os.path.join(res_save_path, "profile_error.log")
detail = builder_config.get('DETAIL', mode)
save_name = save_name or "profiled_results.json"
logging.info("Profiling ...")
@ -135,7 +136,7 @@ def profile_models(backend, models, mode = 'ruletest', metrics = ["latency"], sa
time.sleep(0.2)
count += 1
except Exception as e:
open(os.path.join(info_save_path, "profile_error.log"), 'a').write(f"{id}: {e}\n")
open(error_save_path, 'a').write(f"{id}: {e}\n")
else: # the models have not been converted
try:
model_path = model['model']
@ -145,16 +146,17 @@ def profile_models(backend, models, mode = 'ruletest', metrics = ["latency"], sa
time.sleep(0.2)
count += 1
except Exception as e:
open(os.path.join(info_save_path, "profile_error.log"), 'a').write(f"{id}: {e}\n")
open(error_save_path, 'a').write(f"{id}: {e}\n")
# save information to json file per `log_frequency` models
if count > 0 and count % log_frequency == 0:
save_profiled_results(models, os.path.join(info_save_path, save_name), detail, metrics)
logging.keyinfo(f"{count} models complete. Still profiling... Save the intermediate results to {os.path.join(info_save_path, save_name)}.")
save_profiled_results(models, info_save_path, detail, metrics)
logging.keyinfo(f"{count} models complete. Still profiling... Save the intermediate results to {info_save_path} ")
# save information to json file
save_profiled_results(models, os.path.join(info_save_path, save_name), detail, metrics)
logging.keyinfo(f"All {count} models profiling complete. Save all success profiled results to {os.path.join(info_save_path, save_name)}.")
save_profiled_results(models, info_save_path, detail, metrics)
logging.keyinfo(f"All {count} models profiling complete. Save all success profiled results to {info_save_path} " \
f"Failed information are saved in {error_save_path} (if any)")
return models
@ -192,12 +194,12 @@ def build_predictor_for_kernel(kernel_type, backend, init_sample_num = 1000, fin
init_sample_num (int, optional): the data size for predictor initialization. Defaults to 1000.
finegrained_sample_num (int, optional): the data size for adaptive sampling. For each data point with error higher than
error_threshold, #finegrained_sample_num data will be generated based on the large-error data. Defaults to 10.
error_threshold, `finegrained_sample_num` new data points will be generated based on the large-error data. Defaults to 10.
iteration (int, optional): the iteration for sampling and training. Initial sampling is regarded as iteration 1,
thus `iteration == 2` means one iteration for adaptive sampling. Defaults to 5.
iteration (int, optional): the iteration for sampling and training. Predictor training based on initial sampling is regarded as
iteration 1, thus `iteration == 2` means one iteration for adaptive sampling. Defaults to 5.
error_threshold (float, optional): the threshold of large error. Defaults to 0.2.
error_threshold (float, optional): the threshold of large error. Defaults to 0.1.
predict_label (str): the predicting label to build kernel predictor. Defaults to "latency"

View File

@ -111,35 +111,23 @@ class AvgPool(BaseOperator):
class SE(BaseOperator):
def get_model(self):
class SE(keras.layers.Layer):
def __init__(self, input_shape):
def __init__(self, num_channels, se_ratio=0.25):
super().__init__()
self.in_shape = input_shape
self.conv1 = keras.layers.Conv2D(
filters=self.in_shape[-1] // 4,
kernel_size=[1, 1],
strides=[1, 1],
padding="same",
)
self.conv2 = keras.layers.Conv2D(
filters=self.in_shape[-1],
kernel_size=[1, 1],
strides=[1, 1],
padding="same",
)
self.pool = keras.layers.GlobalAveragePooling2D(keepdims=True)
self.squeeze = keras.layers.Conv2D(filters=int(num_channels * se_ratio), kernel_size=1, padding='same')
self.relu = keras.layers.ReLU()
self.excite = keras.layers.Conv2D(filters=num_channels, kernel_size=1, padding='same')
self.hswish = Hswish().get_model()
def call(self, inputs):
x = tf.nn.avg_pool(
inputs,
ksize=[1] + self.in_shape[0:2] + [1],
strides=[1, 1, 1, 1],
padding="VALID",
)
x = self.conv1(x)
x = tf.nn.relu(x)
x = self.conv2(x)
x = tf.nn.relu6(tf.math.add(x, 3)) * 0.16667
return x * inputs
return SE(self.input_shape)
def call(self, x):
x0 = x
x = self.pool(x)
x = self.squeeze(x)
x = self.relu(x)
x = self.excite(x)
x = self.hswish(x)
return x * x0
return SE(self.input_shape[-1])
class FC(BaseOperator):

View File

@ -10,6 +10,12 @@ logging = logging.getLogger("nn-Meter")
class TorchBlock(BaseBlock):
def __init__(self, config, batch_size = 1):
self.config = config
self.input_shape = [config["CIN"], config["HW"], config["HW"]]
self.input_tensor_shape = [self.input_shape]
self.batch_size = batch_size
def test_block(self):
input_data = torch.randn(1, self.config["CIN"], self.config["HW"], self.config["HW"])
output_data = self.get_model()(input_data)
@ -32,10 +38,7 @@ class TorchBlock(BaseBlock):
class ConvBnRelu(TorchBlock):
def __init__(self, config, batch_size = 1):
self.config = config
self.input_shape = [config["CIN"], config["HW"], config["HW"]]
self.input_tensor_shape = [self.input_shape]
self.batch_size = batch_size
super().__init__(config, batch_size)
conv_op = Conv(self.input_shape, config)
self.conv_op, out_shape = conv_op.get_model(), conv_op.get_output_shape()
@ -65,10 +68,7 @@ class ConvBnRelu(TorchBlock):
class ConvBnRelu6(TorchBlock):
def __init__(self, config, batch_size = 1):
self.config = config
self.input_shape = [config["CIN"], config["HW"], config["HW"]]
self.input_tensor_shape = [self.input_shape]
self.batch_size = batch_size
super().__init__(config, batch_size)
conv_op = Conv(self.input_shape, config)
self.conv_op, out_shape = conv_op.get_model(), conv_op.get_output_shape()
@ -98,10 +98,7 @@ class ConvBnRelu6(TorchBlock):
class ConvBn(TorchBlock):
def __init__(self, config, batch_size = 1):
self.config = config
self.input_shape = [config["CIN"], config["HW"], config["HW"]]
self.input_tensor_shape = [self.input_shape]
self.batch_size = batch_size
super().__init__(config, batch_size)
conv_op = Conv(self.input_shape, config)
self.conv_op, out_shape = conv_op.get_model(), conv_op.get_output_shape()
@ -126,10 +123,7 @@ class ConvBn(TorchBlock):
class ConvRelu(TorchBlock):
def __init__(self, config, batch_size = 1):
self.config = config
self.input_shape = [config["CIN"], config["HW"], config["HW"]]
self.input_tensor_shape = [self.input_shape]
self.batch_size = batch_size
super().__init__(config, batch_size)
conv_op = Conv(self.input_shape, config)
self.conv_op, out_shape = conv_op.get_model(), conv_op.get_output_shape()
@ -154,10 +148,7 @@ class ConvRelu(TorchBlock):
class ConvRelu6(TorchBlock):
def __init__(self, config, batch_size = 1):
self.config = config
self.input_shape = [config["CIN"], config["HW"], config["HW"]]
self.input_tensor_shape = [self.input_shape]
self.batch_size = batch_size
super().__init__(config, batch_size)
conv_op = Conv(self.input_shape, config)
self.conv_op, out_shape = conv_op.get_model(), conv_op.get_output_shape()
@ -182,10 +173,7 @@ class ConvRelu6(TorchBlock):
class ConvHswish(TorchBlock):
def __init__(self, config, batch_size = 1):
self.config = config
self.input_shape = [config["CIN"], config["HW"], config["HW"]]
self.input_tensor_shape = [self.input_shape]
self.batch_size = batch_size
super().__init__(config, batch_size)
conv_op = Conv(self.input_shape, config)
self.conv_op, out_shape = conv_op.get_model(), conv_op.get_output_shape()
@ -210,10 +198,7 @@ class ConvHswish(TorchBlock):
class ConvBlock(TorchBlock):
def __init__(self, config, batch_size = 1):
self.config = config
self.input_shape = [config["CIN"], config["HW"], config["HW"]]
self.input_tensor_shape = [self.input_shape]
self.batch_size = batch_size
super().__init__(config, batch_size)
conv_op = Conv(self.input_shape, config)
self.conv_op = conv_op.get_model()
@ -232,10 +217,7 @@ class ConvBlock(TorchBlock):
class ConvBnHswish(TorchBlock):
def __init__(self, config, batch_size = 1):
self.config = config
self.input_shape = [config["CIN"], config["HW"], config["HW"]]
self.input_tensor_shape = [self.input_shape]
self.batch_size = batch_size
super().__init__(config, batch_size)
conv_op = Conv(self.input_shape, config)
self.conv_op, out_shape = conv_op.get_model(), conv_op.get_output_shape()
@ -265,10 +247,7 @@ class ConvBnHswish(TorchBlock):
class ConvBnReluMaxPool(TorchBlock):
def __init__(self, config, batch_size = 1):
self.config = config
self.input_shape = [config["CIN"], config["HW"], config["HW"]]
self.input_tensor_shape = [self.input_shape]
self.batch_size = batch_size
super().__init__(config, batch_size)
conv_op = Conv(self.input_shape, config)
self.conv_op, out_shape = conv_op.get_model(), conv_op.get_output_shape()
@ -303,10 +282,7 @@ class ConvBnReluMaxPool(TorchBlock):
class DwConvBn(TorchBlock):
def __init__(self, config, batch_size = 1):
self.config = config
self.input_shape = [config["CIN"], config["HW"], config["HW"]]
self.input_tensor_shape = [self.input_shape]
self.batch_size = batch_size
super().__init__(config, batch_size)
dwconv_op = DwConv(self.input_shape, config)
self.dwconv_op, out_shape = dwconv_op.get_model(), dwconv_op.get_output_shape()
@ -331,10 +307,7 @@ class DwConvBn(TorchBlock):
class DwConvRelu(TorchBlock):
def __init__(self, config, batch_size = 1):
self.config = config
self.input_shape = [config["CIN"], config["HW"], config["HW"]]
self.input_tensor_shape = [self.input_shape]
self.batch_size = batch_size
super().__init__(config, batch_size)
dwconv_op = DwConv(self.input_shape, config)
self.dwconv_op, out_shape = dwconv_op.get_model(), dwconv_op.get_output_shape()
@ -359,10 +332,7 @@ class DwConvRelu(TorchBlock):
class DwConvRelu6(TorchBlock):
def __init__(self, config, batch_size = 1):
self.config = config
self.input_shape = [config["CIN"], config["HW"], config["HW"]]
self.input_tensor_shape = [self.input_shape]
self.batch_size = batch_size
super().__init__(config, batch_size)
dwconv_op = DwConv(self.input_shape, config)
self.dwconv_op, out_shape = dwconv_op.get_model(), dwconv_op.get_output_shape()
@ -387,10 +357,7 @@ class DwConvRelu6(TorchBlock):
class DwConvBnRelu(TorchBlock):
def __init__(self, config, batch_size = 1):
self.config = config
self.input_shape = [config["CIN"], config["HW"], config["HW"]]
self.input_tensor_shape = [self.input_shape]
self.batch_size = batch_size
super().__init__(config, batch_size)
dwconv_op = DwConv(self.input_shape, config)
self.dwconv_op, out_shape = dwconv_op.get_model(), dwconv_op.get_output_shape()
@ -420,10 +387,7 @@ class DwConvBnRelu(TorchBlock):
class DwConvBnRelu6(TorchBlock):
def __init__(self, config, batch_size = 1):
self.config = config
self.input_shape = [config["CIN"], config["HW"], config["HW"]]
self.input_tensor_shape = [self.input_shape]
self.batch_size = batch_size
super().__init__(config, batch_size)
dwconv_op = DwConv(self.input_shape, config)
self.dwconv_op, out_shape = dwconv_op.get_model(), dwconv_op.get_output_shape()
@ -453,10 +417,7 @@ class DwConvBnRelu6(TorchBlock):
class DwConvBlock(TorchBlock):
def __init__(self, config, batch_size = 1):
self.config = config
self.input_shape = [config["CIN"], config["HW"], config["HW"]]
self.input_tensor_shape = [self.input_shape]
self.batch_size = batch_size
super().__init__(config, batch_size)
dwconv_op = DwConv(self.input_shape, config)
self.dwconv_op = dwconv_op.get_model()
@ -475,10 +436,7 @@ class DwConvBlock(TorchBlock):
class ConvBnHswish(TorchBlock):
def __init__(self, config, batch_size = 1):
self.config = config
self.input_shape = [config["CIN"], config["HW"], config["HW"]]
self.input_tensor_shape = [self.input_shape]
self.batch_size = batch_size
super().__init__(config, batch_size)
dwconv_op = DwConv(self.input_shape, config)
self.dwconv_op, out_shape = dwconv_op.get_model(), dwconv_op.get_output_shape()
@ -508,10 +466,7 @@ class ConvBnHswish(TorchBlock):
class MaxPoolBlock(TorchBlock):
def __init__(self, config, batch_size = 1):
self.config = config
self.input_shape = [config["CIN"], config["HW"], config["HW"]]
self.input_tensor_shape = [self.input_shape]
self.batch_size = batch_size
super().__init__(config, batch_size)
maxpool_op = MaxPool(self.input_shape, config)
self.maxpool_op = maxpool_op.get_model()
@ -530,10 +485,7 @@ class MaxPoolBlock(TorchBlock):
class AvgPoolBlock(TorchBlock):
def __init__(self, config, batch_size = 1):
self.config = config
self.input_shape = [config["CIN"], config["HW"], config["HW"]]
self.input_tensor_shape = [self.input_shape]
self.batch_size = batch_size
super().__init__(config, batch_size)
avgpool_op = AvgPool(self.input_shape, config)
self.avgpool_op = avgpool_op.get_model()
@ -598,10 +550,7 @@ class ConcatBlock(TorchBlock):
class SplitBlock(TorchBlock):
def __init__(self, config, batch_size = 1):
self.config = config
self.input_shape = [config["CIN"], config["HW"], config["HW"]]
self.input_tensor_shape = [self.input_shape]
self.batch_size = batch_size
super().__init__(config, batch_size)
split_op = Split(self.input_shape, config)
self.split_op = split_op.get_model()
@ -620,10 +569,7 @@ class SplitBlock(TorchBlock):
class ChannelShuffle(TorchBlock):
def __init__(self, config, batch_size = 1):
self.config = config
self.input_shape = [config["CIN"], config["HW"], config["HW"]]
self.input_tensor_shape = [self.input_shape]
self.batch_size = batch_size
super().__init__(config, batch_size)
def get_model(self):
class Model(nn.Module):
@ -642,10 +588,7 @@ class ChannelShuffle(TorchBlock):
class SEBlock(TorchBlock):
def __init__(self, config, batch_size = 1):
self.config = config
self.input_shape = [config["CIN"], config["HW"], config["HW"]]
self.input_tensor_shape = [self.input_shape]
self.batch_size = batch_size
super().__init__(config, batch_size)
se_op = SE(self.input_shape, config)
self.se_op = se_op.get_model()
@ -686,10 +629,7 @@ class GlobalAvgPoolBlock(TorchBlock):
class BnRelu(TorchBlock):
def __init__(self, config, batch_size = 1):
self.config = config
self.input_shape = [config["CIN"], config["HW"], config["HW"]]
self.input_tensor_shape = [self.input_shape]
self.batch_size = batch_size
super().__init__(config, batch_size)
bn_op = BN(self.input_shape, config)
self.bn_op, out_shape = bn_op.get_model(), bn_op.get_output_shape()
@ -714,10 +654,7 @@ class BnRelu(TorchBlock):
class BnBlock(TorchBlock):
def __init__(self, config, batch_size = 1):
self.config = config
self.input_shape = [config["CIN"], config["HW"], config["HW"]]
self.input_tensor_shape = [self.input_shape]
self.batch_size = batch_size
super().__init__(config, batch_size)
bn_op = BN(self.input_shape, config)
self.bn_op = bn_op.get_model()
@ -736,10 +673,7 @@ class BnBlock(TorchBlock):
class HswishBlock(TorchBlock):
def __init__(self, config, batch_size = 1):
self.config = config
self.input_shape = [config["CIN"], config["HW"], config["HW"]]
self.input_tensor_shape = [self.input_shape]
self.batch_size = batch_size
super().__init__(config, batch_size)
hswish_op = Hswish(self.input_shape, config)
self.hswish_op = hswish_op.get_model()
@ -758,10 +692,7 @@ class HswishBlock(TorchBlock):
class ReluBlock(TorchBlock):
def __init__(self, config, batch_size = 1):
self.config = config
self.input_shape = [config["CIN"], config["HW"], config["HW"]]
self.input_tensor_shape = [self.input_shape]
self.batch_size = batch_size
super().__init__(config, batch_size)
relu_op = Relu(self.input_shape, config)
self.relu_op = relu_op.get_model()
@ -780,10 +711,7 @@ class ReluBlock(TorchBlock):
class AddRelu(TorchBlock):
def __init__(self, config, batch_size = 1):
self.config = config
self.input_shape = [config["CIN"], config["HW"], config["HW"]]
self.input_tensor_shape = [self.input_shape]
self.batch_size = batch_size
super().__init__(config, batch_size)
add_op = Add(self.input_shape, config)
self.add_op, out_shape = add_op.get_model(), add_op.get_output_shape()
@ -808,10 +736,7 @@ class AddRelu(TorchBlock):
class AddBlock(TorchBlock):
def __init__(self, config, batch_size = 1):
self.config = config
self.input_shape = [config["CIN"], config["HW"], config["HW"]]
self.input_tensor_shape = [self.input_shape]
self.batch_size = batch_size
super().__init__(config, batch_size)
add_op = Add(self.input_shape, config)
self.add_op = add_op.get_model()

View File

@ -103,24 +103,29 @@ class AvgPool(BaseOperator):
class SE(BaseOperator):
def get_model(self):
from nn_meter.builder.utils import make_divisible
class SE(nn.Module):
def __init__(self, input_shape):
def __init__(self, num_channels, se_ratio=0.25):
super().__init__()
cin = input_shape[0]
self.avgpool = nn.AdaptiveAvgPool2d([1, 1])
self.conv1 = nn.Conv2d(cin, cin // 4, kernel_size=1, stride=1, padding=0)
self.conv2 = nn.Conv2d(cin // 4, cin, kernel_size=1, stride=1, padding=0)
mid_channels = int(num_channels * se_ratio)
self.squeeze = nn.Conv2d(num_channels, mid_channels, kernel_size=1, padding=0)
self.relu = nn.ReLU()
self.excite = nn.Conv2d(mid_channels, num_channels, kernel_size=1, padding=0)
self.hswish = nn.Hardswish()
def forward(self, inputs):
x = self.avgpool(inputs)
x = self.conv1(x)
def _scale(self, x):
x = x.mean(3, keepdim=True).mean(2, keepdim=True)
x = self.squeeze(x)
x = self.relu(x)
x = self.conv2(x)
x = self.excite(x)
x = self.hswish(x)
return x * inputs
return SE(self.input_shape)
return x
def forward(self, x):
scale = self._scale(x)
return scale * x
return SE(self.input_shape[0])
class FC(BaseOperator):
def get_model(self):

View File

@ -1,5 +1,5 @@
import tensorflow.keras as keras
from nn_meter.builder.nn_generator import BaseBlock
from nn_meter.builder.nn_modules import BaseBlock
from nn_meter.builder.kernel_predictor_builder import BaseFeatureParser, BaseConfigSampler

View File

@ -1,5 +1,5 @@
import tensorflow.keras as keras
from nn_meter.builder.nn_generator import BaseOperator
from nn_meter.builder.nn_modules import BaseOperator
class MyOp(BaseOperator):
def get_model(self):

View File

@ -19,9 +19,12 @@ if __name__ == '__main__':
"CIN3": 12,
"CIN4": 12
}
from nn_meter.builder.nn_generator.tf_networks import blocks
from nn_meter.builder.nn_modules.tf_networks import blocks
for kernel in kernels:
getattr(blocks, kernel)(config).test_block()
from nn_meter.builder.nn_generator.torch_networks import blocks
from nn_meter.builder.nn_modules.torch_networks import blocks
for kernel in kernels:
getattr(blocks, kernel)(config).test_block()
from nn_meter.builder.nn_modules.torch_networks import blocks
for kernel in kernels:
getattr(blocks, kernel)(config).test_block()