* Feature: add gnn_dataloader and give GNN example
This commit is contained in:
Wenxuan Liu 2021-09-26 14:48:34 +08:00 committed by GitHub
Parent 1f8c816bca
Commit 1d141ced16
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 554 additions and 2 deletions

View File

@@ -221,6 +221,14 @@ exp.run(exp_config, port)
In `exp_config`, `dummy_input` is required for tracing shape info.
## Bench Dataset
[BRP-NAS](https://arxiv.org/abs/2007.08668v2) proposes an end-to-end latency predictor built on a GCN. Their GCN predictor demonstrates significant improvement over the layer-wise predictor on [NAS-Bench-201](https://arxiv.org/abs/2001.00326). On our bench dataset, however, the performance of BRP-NAS is consistently poor. As discussed in our paper, the reason is the difference in model graphs between the training and testing sets. A GNN learns the representation of model graphs. Although the models in our bench dataset have largely overlapping operator types, their operator configurations, edges, and latency ranges differ.
To better deal with the problems above, we provide a GNN example with an improved graph representation. `GNNDataset` and `GNNDataloader` in `nn_meter/dataset/gnn_dataloader.py` convert the dataset in `.jsonl` format into the Dataset and Dataloader required by the GNN. For specific use cases, please refer to `examples/gnn_for_bench_dataset.ipynb`.
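As a rough illustration of the graph representation, the sketch below mirrors how the `parse_model` method in `gnn_dataloader.py` turns the `inbounds` and `outbounds` lists of each node into directed edges with self-loops. It is a simplified, pure-Python approximation rather than the code shipped in `nn_meter`: `build_edges` and `toy_graph` are hypothetical names, and real records in the `.jsonl` files carry additional fields such as `attr`.

```python
def build_edges(model_graph):
    """Return sorted (src_id, dst_id) pairs for a model graph, self-loops included."""
    # Node ids follow the order in which node names appear in the graph dict.
    name2id = {name: i for i, name in enumerate(model_graph)}
    edges = set()
    for name, node in model_graph.items():
        cur = name2id[name]
        edges.add((cur, cur))  # every node gets a self-loop
        for src in node["inbounds"]:
            if src in name2id:  # names missing from the graph (e.g. weights) are skipped
                edges.add((name2id[src], cur))
        for dst in node["outbounds"]:
            if dst in name2id:
                edges.add((cur, name2id[dst]))
    return sorted(edges)


# Hypothetical three-node model, used only for illustration.
toy_graph = {
    "conv1": {"inbounds": ["conv1/weights"], "outbounds": ["relu1"]},
    "relu1": {"inbounds": ["conv1"], "outbounds": ["fc"]},
    "fc": {"inbounds": ["relu1"], "outbounds": []},
}

edges = build_edges(toy_graph)
print(edges)
```

In `GNNDataloader`, the equivalent source and destination indices are read from the adjacency matrix and passed to `dgl.graph`, and the per-node attribute tensor is attached as `ndata['h']`.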
# Contributing
This project welcomes contributions and suggestions. Most contributions require you to agree to a

View File

@@ -0,0 +1,273 @@
{
"cells": [
{
"cell_type": "markdown",
"source": [
"# Latency Dataset - GNN Model\r\n",
"\r\n",
"Since the dataset is encoded in a graph format, this example uses a GNN to predict model latency on the bench dataset. `GNNDataset` and `GNNDataloader` in `nn_meter/dataset/gnn_dataloader.py` convert the model structures stored in `.jsonl` format into the Dataset and Dataloader we need. We first build our GNN model, which is based on GraphSAGE with max pooling as the pooling method. We then load the data and start training.\r\n",
"\r\n",
"Let's start our journey!"
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"## Step 1: Build our GraphSAGE Model\n",
"\n",
"We build our model with the help of the DGL library."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 1,
"source": [
"import torch\r\n",
"import torch.nn as nn\r\n",
"from torch.nn.modules.module import Module\r\n",
"\r\n",
"from dgl.nn.pytorch.glob import MaxPooling\r\n",
"import dgl.nn as dglnn\r\n",
"from torch.optim.lr_scheduler import CosineAnnealingLR\r\n",
"\r\n",
"\r\n",
"class GNN(Module):\r\n",
"    def __init__(self,\r\n",
"                 num_features=0,\r\n",
"                 num_layers=2,\r\n",
"                 num_hidden=32,\r\n",
"                 dropout_ratio=0):\r\n",
"\r\n",
"        super(GNN, self).__init__()\r\n",
"        self.nfeat = num_features\r\n",
"        self.nlayer = num_layers\r\n",
"        self.nhid = num_hidden\r\n",
"        self.dropout_ratio = dropout_ratio\r\n",
"        self.gc = nn.ModuleList([dglnn.SAGEConv(self.nfeat if i==0 else self.nhid, self.nhid, 'pool') for i in range(self.nlayer)])\r\n",
"        self.bn = nn.ModuleList([nn.LayerNorm(self.nhid) for i in range(self.nlayer)])\r\n",
"        self.relu = nn.ModuleList([nn.ReLU() for i in range(self.nlayer)])\r\n",
"        self.pooling = MaxPooling()\r\n",
"        self.fc = nn.Linear(self.nhid, 1)\r\n",
"        self.fc1 = nn.Linear(self.nhid, self.nhid)\r\n",
"        self.dropout = nn.ModuleList([nn.Dropout(self.dropout_ratio) for i in range(self.nlayer)])\r\n",
"\r\n",
"    def forward_single_model(self, g, features):\r\n",
"        x = self.relu[0](self.bn[0](self.gc[0](g, features)))\r\n",
"        x = self.dropout[0](x)\r\n",
"        for i in range(1, self.nlayer):\r\n",
"            x = self.relu[i](self.bn[i](self.gc[i](g, x)))\r\n",
"            x = self.dropout[i](x)\r\n",
"        return x\r\n",
"\r\n",
"    def forward(self, g, features):\r\n",
"        x = self.forward_single_model(g, features)\r\n",
"        with g.local_scope():\r\n",
"            g.ndata['h'] = x\r\n",
"            x = self.pooling(g, x)\r\n",
"            x = self.fc1(x)\r\n",
"            return self.fc(x)"
],
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"Using backend: pytorch\n"
]
}
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"## Step 2: Load the Data\n",
"\n",
"Next, we load the data and print the sizes of the training and testing datasets."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 2,
"source": [
"import os\r\n",
"from nn_meter.dataset import gnn_dataloader\r\n",
"\r\n",
"target_device = \"cortexA76cpu_tflite21\"\r\n",
"\r\n",
"print(\"Processing Training Set.\")\r\n",
"train_set = gnn_dataloader.GNNDataset(train=True, device=target_device)\r\n",
"print(\"Processing Testing Set.\")\r\n",
"test_set = gnn_dataloader.GNNDataset(train=False, device=target_device)\r\n",
"\r\n",
"train_loader = gnn_dataloader.GNNDataloader(train_set, batchsize=1, shuffle=True)\r\n",
"test_loader = gnn_dataloader.GNNDataloader(test_set, batchsize=1, shuffle=False)\r\n",
"print('Train Dataset Size:', len(train_set))\r\n",
"print('Testing Dataset Size:', len(test_set))\r\n",
"# Read the per-node attribute length once instead of advancing the loader twice.\r\n",
"ATTR_COUNT = next(iter(train_loader))[1].ndata['h'].size(1)\r\n",
"print('Attribute tensor shape:', ATTR_COUNT)"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Processing Training Set.\n",
"Processing Testing Set.\n",
"Train Dataset Size: 20732\n",
"Testing Dataset Size: 5173\n",
"Attribute tensor shape: 26\n"
]
}
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"## Step 3: Run and Test\n",
"\n",
"We can run the model and evaluate it now!"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 3,
"source": [
"if torch.cuda.is_available():\r\n",
"    print(\"Using CUDA.\")\r\n",
"# device = \"cpu\"\r\n",
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\r\n",
"\r\n",
"# Start Training\r\n",
"load_model = False\r\n",
"if load_model:\r\n",
"    model = GNN(ATTR_COUNT, 3, 400, 0.1).to(device)\r\n",
"    opt = torch.optim.AdamW(model.parameters(), lr=4e-4)\r\n",
"    checkpoint = torch.load('LatencyGNN.pt')\r\n",
"    model.load_state_dict(checkpoint['model_state_dict'])\r\n",
"    opt.load_state_dict(checkpoint['optimizer_state_dict'])\r\n",
"    # EPOCHS = checkpoint['epoch']\r\n",
"    EPOCHS = 0\r\n",
"    loss_func = checkpoint['loss']\r\n",
"else:\r\n",
"    model = GNN(ATTR_COUNT, 3, 400, 0.1).to(device)\r\n",
"    opt = torch.optim.AdamW(model.parameters(), lr=4e-4)\r\n",
"    EPOCHS = 20\r\n",
"    loss_func = nn.L1Loss()\r\n",
"\r\n",
"lr_scheduler = CosineAnnealingLR(opt, T_max=EPOCHS)\r\n",
"loss_sum = 0\r\n",
"for epoch in range(EPOCHS):\r\n",
"    train_length = len(train_set)\r\n",
"    train_acc_ten = 0\r\n",
"    loss_sum = 0\r\n",
"    # each batch yields (latency, graph)\r\n",
"    for batched_l, batched_g in train_loader:\r\n",
"        opt.zero_grad()\r\n",
"        batched_l = batched_l.to(device).float()\r\n",
"        batched_g = batched_g.to(device)\r\n",
"        batched_f = batched_g.ndata['h'].float()\r\n",
"        logits = model(batched_g, batched_f)\r\n",
"        # count predictions within 10% of the measured latency\r\n",
"        for i in range(len(batched_l)):\r\n",
"            pred_latency = logits[i].item()\r\n",
"            true_latency = batched_l[i].item()\r\n",
"            if (pred_latency >= 0.9 * true_latency) and (pred_latency <= 1.1 * true_latency):\r\n",
"                train_acc_ten += 1\r\n",
"        # print(\"true latency: \", batched_l)\r\n",
"        # print(\"Predict latency: \", logits)\r\n",
"        batched_l = torch.reshape(batched_l, (-1, 1))\r\n",
"        loss = loss_func(logits, batched_l)\r\n",
"        loss_sum += loss.item()  # accumulate a Python float so the autograd graph is not retained\r\n",
"        loss.backward()\r\n",
"        opt.step()\r\n",
"    lr_scheduler.step()\r\n",
"    print(\"[Epoch \", epoch, \"]: \", \"Training accuracy within 10%: \", train_acc_ten / train_length * 100, \" %.\")\r\n",
"    # print('Learning Rate:', lr_scheduler.get_last_lr())\r\n",
"    # print('Loss:', loss_sum / train_length)\r\n",
"\r\n",
"# Save the final model\r\n",
"torch.save({\r\n",
"    'epoch': EPOCHS,\r\n",
"    'model_state_dict': model.state_dict(),\r\n",
"    'optimizer_state_dict': opt.state_dict(),\r\n",
"    'loss': loss_func,\r\n",
"}, 'LatencyGNN.pt')\r\n",
"\r\n",
"# Start Testing\r\n",
"# NOTE: model.eval() is not called, so dropout stays active during testing.\r\n",
"with torch.no_grad():\r\n",
"    test_length = len(test_set)\r\n",
"    test_acc_ten = 0\r\n",
"    for batched_l, batched_g in test_loader:\r\n",
"        batched_l = batched_l.to(device).float()\r\n",
"        batched_g = batched_g.to(device)\r\n",
"        batched_f = batched_g.ndata['h'].float()\r\n",
"        result = model(batched_g, batched_f)\r\n",
"        # .item() relies on batchsize=1 in test_loader\r\n",
"        if (result.item() >= 0.9 * batched_l.item()) and (result.item() <= 1.1 * batched_l.item()):\r\n",
"            test_acc_ten += 1\r\n",
"    print(\"Testing accuracy within 10%: \", test_acc_ten / test_length * 100, \" %.\")"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"[Epoch 0 ]: Training accuracy within 10%: 21.999807061547365 %.\n",
"[Epoch 1 ]: Training accuracy within 10%: 27.725255643449742 %.\n",
"[Epoch 2 ]: Training accuracy within 10%: 30.228632066370825 %.\n",
"[Epoch 3 ]: Training accuracy within 10%: 31.357322014277443 %.\n",
"[Epoch 4 ]: Training accuracy within 10%: 33.06000385876906 %.\n",
"[Epoch 5 ]: Training accuracy within 10%: 34.917036465367545 %.\n",
"[Epoch 6 ]: Training accuracy within 10%: 36.48466139301563 %.\n",
"[Epoch 7 ]: Training accuracy within 10%: 39.070036658306 %.\n",
"[Epoch 8 ]: Training accuracy within 10%: 40.10708084121165 %.\n",
"[Epoch 9 ]: Training accuracy within 10%: 41.530001929384525 %.\n",
"[Epoch 10 ]: Training accuracy within 10%: 43.26162454177118 %.\n",
"[Epoch 11 ]: Training accuracy within 10%: 45.34053636889832 %.\n",
"[Epoch 12 ]: Training accuracy within 10%: 48.45166891761528 %.\n",
"[Epoch 13 ]: Training accuracy within 10%: 50.945398417904684 %.\n",
"[Epoch 14 ]: Training accuracy within 10%: 54.5774647887324 %.\n",
"[Epoch 15 ]: Training accuracy within 10%: 56.08238471927455 %.\n",
"[Epoch 16 ]: Training accuracy within 10%: 59.54562994404785 %.\n",
"[Epoch 17 ]: Training accuracy within 10%: 62.41076596565696 %.\n",
"[Epoch 18 ]: Training accuracy within 10%: 63.65521898514373 %.\n",
"[Epoch 19 ]: Training accuracy within 10%: 64.6826162454177 %.\n",
"Testing accuracy within 10%: 60.042528513435144 %.\n"
]
}
],
"metadata": {}
}
],
"metadata": {
"interpreter": {
"hash": "0238da245144306487e61782d9cba9bf2e5e19842e5054371ac0cfbea9be2b57"
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -1,3 +1,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
-from .bench_dataset import bench_dataset # TODO: add GNNDataloader and GNNDataset here @wenxuan
+from .bench_dataset import bench_dataset
+from .gnn_dataloader import GNNDataset, GNNDataloader

View File

@@ -1,2 +1,272 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
import random

import dgl
import jsonlines
import torch

from .bench_dataset import bench_dataset

RAW_DATA_URL = "https://github.com/microsoft/nn-Meter/releases/download/v1.0-data/datasets.zip"
__user_dataset_folder__ = os.path.expanduser('~/.nn_meter/dataset')

hws = [
    "cortexA76cpu_tflite21",
    "adreno640gpu_tflite21",
    "adreno630gpu_tflite21",
    "myriadvpu_openvino2019r2",
]


class GNNDataset(torch.utils.data.Dataset):
    def __init__(self, train=True, device="cortexA76cpu_tflite21", split_ratio=0.8):
        """
        Dataset of model graphs and latencies from the Latency Dataset

        Parameters
        ----------
        train: bool
            Get the train dataset or the test dataset
        device: string
            The device type of the corresponding latency (one of `hws`)
        split_ratio: float
            The ratio to split the train dataset and the test dataset.
        """
        err_str = "Not supported device type"
        assert device in hws, err_str
        self.device = device
        self.data_dir = __user_dataset_folder__
        self.train = train
        self.split_ratio = split_ratio
        self.adjs = {}
        self.attrs = {}
        self.nodename2id = {}
        self.id2nodename = {}
        self.op_types = set()
        self.opname2id = {}
        self.raw_data = {}
        self.name_list = []
        self.latencies = {}
        self.download_data()
        self.load_model_archs_and_latencies(self.data_dir)
        self.construct_attrs()
        self.name_list = list(
            filter(lambda x: x in self.latencies, self.name_list))

    def download_data(self):
        datasets = bench_dataset()

    def load_model_archs_and_latencies(self, data_dir):
        filelist = os.listdir(data_dir)
        for filename in filelist:
            if os.path.splitext(filename)[-1] != '.jsonl':
                continue
            self.load_model(os.path.join(data_dir, filename))

    def load_model(self, fpath):
        """
        Load a concrete model type.
        """
        # print('Loading models in ', fpath)
        assert os.path.exists(fpath), '{} does not exist'.format(fpath)
        # First pass: collect the ids and latencies measured on this device.
        with jsonlines.open(fpath) as reader:
            _names = []
            for obj in reader:
                if obj[self.device]:
                    # print(obj['id'])
                    _names.append(obj['id'])
                    self.latencies[obj['id']] = float(obj[self.device])
        _names = sorted(_names)
        split_ratio = self.split_ratio if self.train else 1-self.split_ratio
        count = int(len(_names) * split_ratio)
        if self.train:
            _model_names = _names[:count]
        else:
            _model_names = _names[-1*count:]
        self.name_list.extend(_model_names)
        # Second pass: parse the graphs of the models selected for this split.
        with jsonlines.open(fpath) as reader:
            for obj in reader:
                if obj['id'] in _model_names:
                    model_name = obj['id']
                    model_data = obj['graph']
                    self.parse_model(model_name, model_data)
                    self.raw_data[model_name] = model_data

    def construct_attrs(self):
        """
        Construct the attributes matrix for each model.
        Attributes tensor:
        one-hot encoded type + input_channel, output_channel,
        input_h, input_w + kernel_size + stride
        """
        op_types_list = list(sorted(self.op_types))
        for i, _op in enumerate(op_types_list):
            self.opname2id[_op] = i
        n_op_type = len(self.op_types)
        attr_len = n_op_type + 6
        for model_name in self.raw_data:
            n_node = len(self.raw_data[model_name])
            # print("Model: ", model_name, " Number of Nodes: ", n_node)
            t_attr = torch.zeros(n_node, attr_len)
            for node in self.raw_data[model_name]:
                node_attr = self.raw_data[model_name][node]
                nid = self.nodename2id[model_name][node]
                op_type = node_attr['attr']['type']
                op_id = self.opname2id[op_type]
                t_attr[nid][op_id] = 1
                other_attrs = self.parse_node(model_name, node)
                t_attr[nid][-6:] = other_attrs
            self.attrs[model_name] = t_attr

    def parse_node(self, model_name, node_name):
        """
        Parse the attributes of the specified node.
        Get the input_c, output_c, input_h, input_w, kernel_size, stride
        of this node. Note: an entry is filled with 0 by default if the
        node doesn't have the corresponding attribute.
        """
        node_data = self.raw_data[model_name][node_name]
        t_attr = torch.zeros(6)
        op_type = node_data['attr']['type']
        if op_type in ('Conv2D', 'DepthwiseConv2dNative'):
            weight_shape = node_data['attr']['attr']['weight_shape']
            kernel_size, _, in_c, out_c = weight_shape
            stride, _ = node_data['attr']['attr']['strides']
            _, h, w, _ = node_data['attr']['output_shape'][0]
            t_attr = torch.tensor([in_c, out_c, h, w, kernel_size, stride])
        elif op_type == 'MatMul':
            in_node = node_data['inbounds'][0]
            in_shape = self.raw_data[model_name][in_node]['attr']['output_shape'][0]
            in_c = in_shape[-1]
            out_c = node_data['attr']['output_shape'][0][-1]
            t_attr[0] = in_c
            t_attr[1] = out_c
        elif len(node_data['inbounds']):
            in_node = node_data['inbounds'][0]
            h, w, in_c, out_c = 0, 0, 0, 0
            in_shape = self.raw_data[model_name][in_node]['attr']['output_shape'][0]
            in_c = in_shape[-1]
            if 'ConCat' in op_type:
                # Sum the channel counts of the remaining inputs.
                for i in range(1, len(node_data['inbounds'])):
                    in_shape = self.raw_data[model_name][node_data['inbounds'][i]]['attr']['output_shape'][0]
                    in_c += in_shape[-1]
            if len(node_data['attr']['output_shape']):
                out_shape = node_data['attr']['output_shape'][0]
                # N, H, W, C
                out_c = out_shape[-1]
                if len(out_shape) == 4:
                    h, w = out_shape[1], out_shape[2]
            t_attr[-6:-2] = torch.tensor([in_c, out_c, h, w])
        return t_attr

    def parse_model(self, model_name, model_data):
        """
        Parse the model data and build the adjacency matrices
        """
        n_nodes = len(model_data)
        m_adj = torch.zeros(n_nodes, n_nodes, dtype=torch.int32)
        id2name = {}
        name2id = {}
        tmp_node_id = 0
        # build the mapping between the node name and node id
        for node_name in model_data.keys():
            id2name[tmp_node_id] = node_name
            name2id[node_name] = tmp_node_id
            op_type = model_data[node_name]['attr']['type']
            self.op_types.add(op_type)
            tmp_node_id += 1
        for node_name in model_data:
            cur_id = name2id[node_name]
            for node in model_data[node_name]['inbounds']:
                if node not in name2id:
                    # weight node
                    continue
                in_id = name2id[node]
                m_adj[in_id][cur_id] = 1
            for node in model_data[node_name]['outbounds']:
                if node not in name2id:
                    # weight node
                    continue
                out_id = name2id[node]
                m_adj[cur_id][out_id] = 1
        # add a self-loop to every node
        for idx in range(n_nodes):
            m_adj[idx][idx] = 1
        self.adjs[model_name] = m_adj
        self.nodename2id[model_name] = name2id
        self.id2nodename[model_name] = id2name

    def __getitem__(self, index):
        model_name = self.name_list[index]
        return (self.adjs[model_name], self.attrs[model_name]), self.latencies[model_name], self.op_types

    def __len__(self):
        return len(self.name_list)


class GNNDataloader(torch.utils.data.DataLoader):
    def __init__(self, dataset, shuffle=False, batchsize=1):
        self.dataset = dataset
        self.op_num = len(dataset.op_types)
        self.shuffle = shuffle
        self.batchsize = batchsize
        self.length = len(self.dataset)
        self.indexes = list(range(self.length))
        self.pos = 0
        self.graphs = {}
        self.latencies = {}
        self.construct_graphs()

    def construct_graphs(self):
        for gid in range(self.length):
            (adj, attrs), latency, op_types = self.dataset[gid]
            # edges are the nonzero entries of the adjacency matrix
            u, v = torch.nonzero(adj, as_tuple=True)
            graph = dgl.graph((u, v))
            # one-hot columns are divided by 1; the six numeric columns
            # (in_c, out_c, h, w, kernel_size, stride) by fixed maxima
            MAX_NORM = torch.tensor([1]*len(op_types) + [6963, 6963, 224, 224, 11, 4])
            attrs = attrs / MAX_NORM
            graph.ndata['h'] = attrs
            self.graphs[gid] = graph
            self.latencies[gid] = latency

    def __iter__(self):
        if self.shuffle:
            random.shuffle(self.indexes)
        self.pos = 0
        return self

    def __len__(self):
        return self.length

    def __next__(self):
        start = self.pos
        end = min(start + self.batchsize, self.length)
        self.pos = end
        if end - start <= 0:
            raise StopIteration
        batch_indexes = self.indexes[start:end]
        batch_graphs = [self.graphs[i] for i in batch_indexes]
        batch_latencies = [self.latencies[i] for i in batch_indexes]
        return torch.tensor(batch_latencies), dgl.batch(batch_graphs)