Move inconsistency resolution from ir_converters to kerneldetection; resolve inconsistency between ONNX and TensorFlow

Jianyu Wei 2021-05-20 02:37:18 -04:00
Parent 251f9b719b
Commit ef4e4e7582
12 changed files with 153 additions and 162 deletions
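In effect, op and attribute canonicalization moves out of the ONNX converter and into kernel detection, so graphs coming from ONNX models, frozen TensorFlow pb files, or pre-dumped JSON are normalized in one place before fusion-rule matching. A minimal sketch of the resulting flow, pieced together from the new kernel_detection.py entry point below; the model path is a placeholder and error handling is omitted:

from ir_converters import model_file_to_grapher
from kerneldetection import KernelDetector

graph = model_file_to_grapher('model.onnx')                   # framework-specific op names
kd = KernelDetector('data/fusionrules/rule_tflite_cpu.json')  # fusion rules for the target backend
kd.load_graph(graph)                                          # convert_nodes() canonicalizes types/attrs here
kernels = kd.kernels                                          # detected kernel configurations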

View file

@@ -18,31 +18,3 @@ TRANSPOSE_TYPE = 'Transpose'
REDUCEMEAN_TYPE = 'ReduceMean'
SPLIT_TYPE = 'Split'
PAD_TYPE = 'Pad'
OP_ALIAS = {
CONV_TYPE: 'conv',
BN_TYPE: 'bn',
SLICE_TYPE: 'split',
CONCAT_TYPE: 'concat',
MAXPOOL_TYPE: 'maxpool',
AVGPOOL_TYPE: 'avgpool',
RELU_TYPE: 'relu',
ADD_TYPE: 'add',
FC_TYPE: 'fc',
RESHAPE_TYPE: 'reshape',
GAP_TYPE: 'gap',
CLIP_TYPE: 'clip',
MUL_TYPE: 'mul',
DIV_TYPE: 'div',
HARDSIGMOID_TYPE: 'hardsigmoid',
FLATTEN_TYPE: 'flatten',
TRANSPOSE_TYPE: 'transpose',
REDUCEMEAN_TYPE: 'reducemean',
SPLIT_TYPE: 'split',
PAD_TYPE: 'pad',
}
ATTR_ALIAS = {
'pads': ('padding', '__all__'),
'axis': ('split_dim', ['split']),
}

View file

@@ -1,5 +1,5 @@
import networkx as nx
from .utils import get_tensor_shape, convert_attr
from .utils import get_tensor_shape
from .constants import *
from itertools import chain
import logging
@@ -75,11 +75,8 @@ class OnnxConverter:
if len(input_tensors) == 0 or len(input_tensors[0]) <= 1 or len(output_tensors) == 0 or len(output_tensors[0]) <= 1:
return attrs
if node.op_type not in OP_ALIAS:
logging.warning(f'Unsupported OP: {node.op_type}')
attrs['attr'] = {}
attrs['type'] = OP_ALIAS.get(node.op_type, node.op_type)
attrs['type'] = node.op_type
attrs['input_shape'] = input_tensors
attrs['output_shape'] = output_tensors
for attr in node.attribute:
@@ -103,7 +100,6 @@ class OnnxConverter:
node_attrs = self.G.nodes[node]
if node in self.tensors or not node_attrs:
continue
node_attrs['attr'] = convert_attr(node_attrs['attr'], node_attrs['type'])
outbounds = []
inbounds = []

View file

@@ -8,24 +8,3 @@ def get_tensor_shape(tensor):
if len(shape) == 4:
shape = [shape[0], shape[2], shape[3], shape[1]]
return shape
def convert_attr(attr, type):
def is_type(type, ts):
if ts is None:
return False
elif ts == '__all__':
return True
else:
return type in ts
new_attr = {}
for name, value in attr.items():
new_name, ts = ATTR_ALIAS.get(name, (name, None))
if is_type(type, ts):
new_attr[new_name] = value
else:
new_attr[name] = value
return new_attr

View file

@@ -1,15 +1,10 @@
import onnx
import json
from .onnx_converter import OnnxConverter
from .frozenpb_converter import FrozenPbConverter
def model_to_grapher(model, model_type=None):
if model_type is None:
if isinstance(model, onnx.ModelProto):
model_type = 'onnx'
else:
raise ValueError(f'Invalid model: {type(model)}')
def model_to_grapher(model, model_type):
if model_type == 'onnx':
converter = OnnxConverter(model)
result = converter.convert()
@@ -26,16 +21,21 @@ def model_file_to_grapher(filename, model_type=None):
if filename.endswith('.onnx'):
model_type = 'onnx'
elif filename.endswith('.pb'):
converter = FrozenPbConverter(filename)
return converter.get_flatten_grapher()
model_type = 'pb'
elif filename.endswith('.json'):
model_type = 'json'
else:
raise ValueError(f'Unknown file type: {filename}')
if model_type == 'onnx':
model = onnx.load(filename)
return model_to_grapher(model, model_type)
elif model_type == 'pb':
raise NotImplementedError
converter = FrozenPbConverter(filename)
return converter.get_flatten_grapher()
elif model_type == 'json':
with open(filename, 'r') as fp:
return json.load(fp)
else:
raise ValueError(f'Unsupported model type: {model_type}')
return model_to_grapher(model, model_type)
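With the refactored dispatch, the file extension alone selects the converter. A short usage sketch, assuming the package layout in this repo; the file names are placeholders:

from ir_converters import model_file_to_grapher

onnx_graph = model_file_to_grapher('model.onnx')  # onnx.load + OnnxConverter
pb_graph = model_file_to_grapher('model.pb')      # FrozenPbConverter.get_flatten_grapher
json_graph = model_file_to_grapher('graph.json')  # pre-dumped graph, loaded verbatim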

kernel_detection.py (new file, 33 additions)
View file

@@ -0,0 +1,33 @@
from kerneldetection import KernelDetector
from ir_converters import model_file_to_grapher
import argparse
import json
BACKENDS = {
'cpu': 'tflite_cpu',
'gpu': 'tflite_gpu',
'vpu': 'vpu',
}
def main(input_model, rule_file, output_path):
graph = model_file_to_grapher(input_model)
kd = KernelDetector(rule_file)
kd.load_graph(graph)
with open(output_path, 'w') as fp:
json.dump(kd.kernels, fp, indent=2)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--hardware', type=str, default='cpu')
parser.add_argument('-i', '--input_model', type=str, required=True, help='Path to input models. Either pb or onnx')
parser.add_argument('-o', '--output_path', type=str, default='out.json')
parser.add_argument('-r', '--rule_file', type=str, default='data/fusionrules/rule_tflite_cpu.json')
#parser.add_argument('-t', '--input_type', type=str, choices=['multi-m','single-m'], default='multi-m', help='input file type: multi-m or single-m')
#parser.add_argument('-backend', '--backend', type=str, choices=['tflite_cpu','tflite_gpu','vpu'], default='tflite_cpu', help='Default preserve the original layer names. Readable will assign new kernel names according to types of the layers.')
args = parser.parse_args()
main(args.input_model, args.rule_file, args.output_path)

View file

@@ -1,24 +0,0 @@
DUMMY_TYPES = [
'Const',
'Identity',
'Placeholder',
]
TENSORFLOW_OP_ALIAS = {
'Relu6': 'relu',
'Relu': 'relu',
'Add': 'add',
'Biasadd': 'add',
'Conv2D': 'conv',
'Reshape': 'reshape',
'FusedBatchNorm': 'bn',
'FusedBatchNormV3': 'bn',
'MatMul': 'fc',
'MaxPool': 'maxpool',
'AvgPool': 'avgpool',
'Mean': 'gap',
'Mul': 'mul',
'DepthwiseConv2dNative': 'dwconv',
'ConcatV2': 'concat',
'Split': 'split',
}

View file

@@ -1,7 +1,8 @@
from kerneldetection.rulelib.rule_reader import RuleReader
from kerneldetection.rulelib.rule_splitter import RuleSplitter
from utils.grapher_tool import Grapher
from .constants import DUMMY_TYPES
from kerneldetection.utils.constants import *
from kerneldetection.utils.ir_tools import convert_nodes
class KernelDetector:
@@ -12,7 +13,8 @@ class KernelDetector:
self.bbs = []
def load_graph(self, graph):
self.graph = Grapher(graph=graph)
new_graph = convert_nodes(graph)
self.graph = Grapher(graph=new_graph)
self.bbs = self.splitter.split(self.graph)
@property
@@ -43,35 +45,25 @@ class KernelDetector:
attr = self.graph.get_node_attr(layer)['attr']
input_shape = self.graph.get_node_attr(layer)['input_shape']
output_shape = self.graph.get_node_attr(layer)['output_shape']
if type in ['conv', 'dwconv']:
kernel['ks'] = attr['ks']
kernel['cin'] = input_shape[0][3]
kernel['cout'] = output_shape[0][3]
kernel['strides'] = attr['strides']
if type == 'dwconv':
kernel['cout'] = kernel['cin']
elif type in ['maxpool', 'avgpool']:
kernel['ks'] = attr['ksize']
kernel['cin'] = input_shape[0][3]
kernel['cout'] = output_shape[0][3]
kernel['strides'] = attr['strides']
elif type == 'fc':
kernel['cin'] = input_shape[0][1]
kernel['cout'] = output_shape[0][1]
elif type == 'gap':
kernel['cin'] = input_shape[0][3]
kernel['cout'] = output_shape[0][3]
elif type in ['relu','hswish']:
kernel['cin'] = input_shape[-1]
kernel['cout'] = output_shape[-1]
kernel['input_tensors'] = input_shape
if type not in ['relu','bn', 'fc', 'reshape', 'Pack', 'StridedSlice','split']:
kernel['inputh'] = input_shape[0][1]
kernel['inputw'] = input_shape[0][2]
if type == 'split':
if 'ks' in attr:
kernel['ks'] = attr['ks']
if 'strides' in attr:
kernel['strides'] = attr['strides']
if 'split_dim' in attr:
kernel['split_dim'] = attr['split_dim']
if len(input_shape) == 1:
if len(input_shape[0]) == 4:
kernel['inputh'] = input_shape[0][1]
kernel['inputw'] = input_shape[0][2]
kernel['cin'] = input_shape[0][-1]
if len(output_shape) == 1:
kernel['cout'] = output_shape[0][-1]
elif len(output_shape) > 1:
kernel['output_tensors'] = output_shape
return kernel
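The per-type branches removed above give way to shape-driven defaults. A simplified standalone sketch of that fallback (not the full method, which also copies attributes such as ks and strides from the canonicalized attr dict), with a worked conv example:

def shape_defaults(kernel, input_shape, output_shape):
    # Single input tensor: take spatial size and channels from its NHWC shape.
    if len(input_shape) == 1:
        if len(input_shape[0]) == 4:
            kernel['inputh'] = input_shape[0][1]
            kernel['inputw'] = input_shape[0][2]
        kernel['cin'] = input_shape[0][-1]
    # Single output tensor: channels only; multiple outputs keep the raw shape list.
    if len(output_shape) == 1:
        kernel['cout'] = output_shape[0][-1]
    elif len(output_shape) > 1:
        kernel['output_tensors'] = output_shape
    return kernel

# A 3x3 stride-2 conv over a 224x224x3 input producing 32 channels:
shape_defaults({}, [[1, 224, 224, 3]], [[1, 112, 112, 32]])
# -> {'inputh': 224, 'inputw': 224, 'cin': 3, 'cout': 32}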

View file

@@ -1,9 +1,13 @@
import os
import json
from utils.grapher_tool import Grapher
from kerneldetection.utils.ir_tools import convert_nodes
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
def get_fusion_unit(name):
filename = os.path.join(BASE_DIR, f'{name}_fusionunit.json')
return Grapher(filename)
with open(filename, 'r') as fp:
graph = convert_nodes(json.load(fp))
return Grapher(graph=graph)

View file

@@ -0,0 +1,44 @@
DUMMY_TYPES = [
'Const',
'Identity',
'Placeholder',
]
OP_ALIAS = {
# Tensorflow
'Relu6': 'relu',
'Relu': 'relu',
'Add': 'add',
'Biasadd': 'add',
'Conv2D': 'conv',
'Reshape': 'reshape',
'FusedBatchNorm': 'bn',
'FusedBatchNormV3': 'bn',
'MatMul': 'fc',
'MaxPool': 'maxpool',
'AvgPool': 'avgpool',
'Mean': 'gap',
'Mul': 'mul',
'DepthwiseConv2dNative': 'dwconv',
'ConcatV2': 'concat',
'Split': 'split',
# ONNX
'Conv': 'conv',
'BatchNormalization': 'bn',
'Slice': 'split',
'Concat': 'concat',
'AveragePool': 'avgpool',
'Relu': 'relu',
'Add': 'add',
'Gemm': 'fc',
'GlobalAveragePool': 'gap',
'Clip': 'relu',
'Mul': 'mul',
'Div': 'div',
'HardSigmoid': 'hardsigmoid',
'Flatten': 'reshape',
'Transpose': 'transpose',
'ReduceMean': 'gap',
'Split': 'split',
'Pad': 'pad',
}

View file

@@ -0,0 +1,33 @@
import copy
from .constants import *
def convert_nodes(graph):
'''
Resolve inconsistency between ONNX and Tensorflow
'''
new_graph = copy.deepcopy(graph)
for _, node in new_graph.items():
type = node['attr']['type']
new_type = OP_ALIAS.get(type, type)
node['attr']['type'] = new_type
attr = node['attr']['attr']
if 'kernel_shape' in attr:
attr['ks'] = attr['kernel_shape']
del attr['kernel_shape']
if 'weight_shape' in attr and attr['weight_shape'] is not None:
attr['ks'] = attr['weight_shape'][0:2]
del attr['weight_shape']
if 'ksize' in attr:
attr['ks'] = attr['ksize']
del attr['ksize']
if new_type == 'split' and 'axis' in attr:
attr['split_dim'] = attr['axis']
del attr['axis']
return new_graph
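For example, a node carrying a framework-specific type and a kernel_shape attribute comes out of convert_nodes with the canonical type and the unified ks key. Only the keys the function touches are shown here; real graphs also carry shapes and inbound/outbound edges:

from kerneldetection.utils.ir_tools import convert_nodes

graph = {
    'conv1': {
        'attr': {
            'type': 'Conv2D',  # TensorFlow spelling; ONNX 'Conv' maps to the same canonical type
            'attr': {'kernel_shape': [3, 3], 'strides': [1, 1, 1, 1]},
        },
    },
}

new_graph = convert_nodes(graph)
new_graph['conv1']['attr']['type']  # 'conv'
new_graph['conv1']['attr']['attr']  # {'strides': [1, 1, 1, 1], 'ks': [3, 3]}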

View file

@@ -1,53 +1,13 @@
class MatchHelper:
base_type_table = {
'ReLU': [
'Relu',
'Relu6',
'ReLU',
'ReLU6',
],
'BatchNorm': [
'BatchNorm',
'FusedBatchNorm',
'FusedBatchNormV2',
'FusedBatchNormV3',
],
'TwoInputElementWise': [
'BiasAdd',
'Add',
'Mul',
],
'DepthwiseConv2D': [
'DepthwiseConv2dNative',
],
'FC': [
'MatMul',
]
}
@classmethod
def get_base_type(cls, node_type):
for key, value in cls.base_type_table.items():
if node_type in value:
return key
return node_type
@classmethod
def op_type_matcher(cls, node_1, node_2):
def get_ast_by_op(op_name):
for key, value in cls.base_type_table.items():
if op_name in value:
return key
return op_name
if 'type' in node_1 and 'type' in node_2:
if '_tagged' in node_1 or '_tagged' in node_2:
return False
if node_1['type'] == 'dummy' or node_2['type'] == 'dummy':
return True
return get_ast_by_op(node_1['type']) == get_ast_by_op(node_2['type'])
return node_1['type'] == node_2['type']
else:
return False

View file

@@ -83,7 +83,7 @@ class Grapher:
self.graph[name]['inbounds'].remove(inbound)
def remove_node_outbounds(self, name, outbound):
if outbound in self.graph[name]['outbound']:
if outbound in self.graph[name]['outbounds']:
self.graph[name]['outbounds'].remove(outbound)
def add_node_inbounds(self, name, inbound):
@@ -132,7 +132,7 @@ class Grapher:
return None
def get_root_node(self, subgraph):
root = subgraph[0]
root = next(iter(subgraph))
flag = True
while flag:
@@ -143,6 +143,8 @@ class Grapher:
root = inbound
break
return root
def fuse(self, subgraph, type, name=None, attr=None, is_block=True):
'''
subgraph: list of node name