зеркало из https://github.com/microsoft/caffe.git
[tools] add Python script for at-a-glance prototxt summary
This commit is contained in:
Родитель
2e1c1cb778
Коммит
84eb44e6cf
|
@ -0,0 +1,140 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
"""Net summarization tool.
|
||||
|
||||
This tool summarizes the structure of a net in a concise but comprehensive
|
||||
tabular listing, taking a prototxt file as input.
|
||||
|
||||
Use this tool to check at a glance that the computation you've specified is the
|
||||
computation you expect.
|
||||
"""
|
||||
|
||||
from caffe.proto import caffe_pb2
|
||||
from google import protobuf
|
||||
import re
|
||||
import argparse
|
||||
|
||||
# ANSI codes for coloring blobs (used cyclically)
|
||||
COLORS = ['92', '93', '94', '95', '97', '96', '42', '43;30', '100',
|
||||
'444', '103;30', '107;30']
|
||||
DISCONNECTED_COLOR = '41'
|
||||
|
||||
def read_net(filename):
|
||||
net = caffe_pb2.NetParameter()
|
||||
with open(filename) as f:
|
||||
protobuf.text_format.Parse(f.read(), net)
|
||||
return net
|
||||
|
||||
def format_param(param):
|
||||
out = []
|
||||
if len(param.name) > 0:
|
||||
out.append(param.name)
|
||||
if param.lr_mult != 1:
|
||||
out.append('x{}'.format(param.lr_mult))
|
||||
if param.decay_mult != 1:
|
||||
out.append('Dx{}'.format(param.decay_mult))
|
||||
return ' '.join(out)
|
||||
|
||||
def printed_len(s):
|
||||
return len(re.sub(r'\033\[[\d;]+m', '', s))
|
||||
|
||||
def print_table(table, max_width):
|
||||
"""Print a simple nicely-aligned table.
|
||||
|
||||
table must be a list of (equal-length) lists. Columns are space-separated,
|
||||
and as narrow as possible, but no wider than max_width. Text may overflow
|
||||
columns; note that unlike string.format, this will not affect subsequent
|
||||
columns, if possible."""
|
||||
|
||||
max_widths = [max_width] * len(table[0])
|
||||
column_widths = [max(printed_len(row[j]) + 1 for row in table)
|
||||
for j in range(len(table[0]))]
|
||||
column_widths = [min(w, max_w) for w, max_w in zip(column_widths, max_widths)]
|
||||
|
||||
for row in table:
|
||||
row_str = ''
|
||||
right_col = 0
|
||||
for cell, width in zip(row, column_widths):
|
||||
right_col += width
|
||||
row_str += cell + ' '
|
||||
row_str += ' ' * max(right_col - printed_len(row_str), 0)
|
||||
print row_str
|
||||
|
||||
def summarize_net(net):
|
||||
disconnected_tops = set()
|
||||
for lr in net.layer:
|
||||
disconnected_tops |= set(lr.top)
|
||||
disconnected_tops -= set(lr.bottom)
|
||||
|
||||
table = []
|
||||
colors = {}
|
||||
for lr in net.layer:
|
||||
tops = []
|
||||
for ind, top in enumerate(lr.top):
|
||||
color = colors.setdefault(top, COLORS[len(colors) % len(COLORS)])
|
||||
if top in disconnected_tops:
|
||||
top = '\033[1;4m' + top
|
||||
if len(lr.loss_weight) > 0:
|
||||
top = '{} * {}'.format(lr.loss_weight[ind], top)
|
||||
tops.append('\033[{}m{}\033[0m'.format(color, top))
|
||||
top_str = ', '.join(tops)
|
||||
|
||||
bottoms = []
|
||||
for bottom in lr.bottom:
|
||||
color = colors.get(bottom, DISCONNECTED_COLOR)
|
||||
bottoms.append('\033[{}m{}\033[0m'.format(color, bottom))
|
||||
bottom_str = ', '.join(bottoms)
|
||||
|
||||
if lr.type == 'Python':
|
||||
type_str = lr.python_param.module + '.' + lr.python_param.layer
|
||||
else:
|
||||
type_str = lr.type
|
||||
|
||||
# Summarize conv/pool parameters.
|
||||
# TODO support rectangular/ND parameters
|
||||
conv_param = lr.convolution_param
|
||||
if (lr.type in ['Convolution', 'Deconvolution']
|
||||
and len(conv_param.kernel_size) == 1):
|
||||
arg_str = str(conv_param.kernel_size[0])
|
||||
if len(conv_param.stride) > 0 and conv_param.stride[0] != 1:
|
||||
arg_str += '/' + str(conv_param.stride[0])
|
||||
if len(conv_param.pad) > 0 and conv_param.pad[0] != 0:
|
||||
arg_str += '+' + str(conv_param.pad[0])
|
||||
arg_str += ' ' + str(conv_param.num_output)
|
||||
if conv_param.group != 1:
|
||||
arg_str += '/' + str(conv_param.group)
|
||||
elif lr.type == 'Pooling':
|
||||
arg_str = str(lr.pooling_param.kernel_size)
|
||||
if lr.pooling_param.stride != 1:
|
||||
arg_str += '/' + str(lr.pooling_param.stride)
|
||||
if lr.pooling_param.pad != 0:
|
||||
arg_str += '+' + str(lr.pooling_param.pad)
|
||||
else:
|
||||
arg_str = ''
|
||||
|
||||
if len(lr.param) > 0:
|
||||
param_strs = map(format_param, lr.param)
|
||||
if max(map(len, param_strs)) > 0:
|
||||
param_str = '({})'.format(', '.join(param_strs))
|
||||
else:
|
||||
param_str = ''
|
||||
else:
|
||||
param_str = ''
|
||||
|
||||
table.append([lr.name, type_str, param_str, bottom_str, '->', top_str,
|
||||
arg_str])
|
||||
return table
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Print a concise summary of net computation.")
|
||||
parser.add_argument('filename', help='net prototxt file to summarize')
|
||||
parser.add_argument('-w', '--max-width', help='maximum field width',
|
||||
type=int, default=30)
|
||||
args = parser.parse_args()
|
||||
|
||||
net = read_net(args.filename)
|
||||
table = summarize_net(net)
|
||||
print_table(table, max_width=args.max_width)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Загрузка…
Ссылка в новой задаче