Update the drawnet.py to reflect the recent revised net definition.

This commit is contained in:
ZhiHeng NIU 2014-04-25 17:43:56 +08:00
Родитель c55ebd0981
Коммит 65015e3712
1 изменённых файлов: 11 добавлений и 4 удалений

Просмотреть файл

@ -15,14 +15,21 @@ NEURON_LAYER_STYLE = {'shape': 'record', 'fillcolor': '#90EE90',
'style': 'filled'}
BLOB_STYLE = {'shape': 'octagon', 'fillcolor': '#F0E68C',
'style': 'filled'}
def get_enum_name_by_value():
desc = caffe_pb2.LayerParameter.LayerType.DESCRIPTOR
d = {}
for k,v in desc.values_by_name.items():
d[v.number] = k
return d
def get_pydot_graph(caffe_net):
pydot_graph = pydot.Dot(caffe_net.name, graph_type='digraph')
pydot_graph = pydot.Dot(caffe_net.name, graph_type='digraph', rankdir="BT")
pydot_nodes = {}
pydot_edges = []
d = get_enum_name_by_value()
for layer in caffe_net.layers:
name = layer.layer.name
layertype = layer.layer.type
name = layer.name
layertype = d[layer.type]
if (len(layer.bottom) == 1 and len(layer.top) == 1 and
layer.bottom[0] == layer.top[0]):
# We have an in-place neuron layer.
@ -63,7 +70,7 @@ def draw_net_to_file(caffe_net, filename):
to graphviz to draw graphs.
"""
ext = filename[filename.rfind('.')+1:]
with open(filename, 'w') as fid:
with open(filename, 'wb') as fid:
fid.write(draw_net(caffe_net, ext))
if __name__ == '__main__':