This commit is contained in:
Alona Kharchenko 2016-10-25 20:43:04 +02:00
Родитель e55cd0afa8
Коммит a9ac4e090a
1 изменённых файлов: 60 добавлений и 0 удалений

Просмотреть файл

@ -29,6 +29,66 @@ def dfs_walk(node, visitor, accum, visited):
if visitor(node):
accum.append(node)
def build_graph(node, visitor, accum, visited, dot_object):
'''
Generic function to build the graph.
Args:
node (graph node): the node to start the journey from
visitor (Python function or lambda): function that takes a node as
argument and returns `True` if that node should be returned.
accum (`list`): accumulator of nodes while traversing the graph
visited (`set`): set of nodes that have already been visited.
Initialize with empty set.
dot_object(`Pydot.Dot`): contains the graph description in
dot format
'''
import pydot
if node in visited:
return
visited.add(node)
if hasattr(node, 'root_function'):
node = node.root_function
cur_node = pydot.Node(node.op_name+' '+node.uid, label=node.op_name)
dot_object.add_node(cur_node)
out_node = pydot.Node(node.outputs[0].uid, label=node.outputs[0].uid
+ '\nshape:\n' + str(node.outputs[0].shape))
dot_object.add_node(out_node)
dot_object.add_edge(pydot.Edge(cur_node,out_node))
for child in node.inputs:
child_node = pydot.Node(child.uid, label=child.uid + '\nshape:\n' + str(child.shape))
dot_object.add_node(child_node)
dot_object.add_edge(pydot.Edge(child_node, cur_node))
dfs_walk_plot(child, visitor, accum, visited, dot_object)
elif hasattr(node, 'is_output') and node.is_output:
dfs_walk_plot(node.owner, visitor, accum, visited, dot_object)
if visitor(node):
accum.append(node)
def png_graph(model, path):
'''
Saves the network graph to the file
Args:
model(`cntk.ops.functions.Function`): model to plot
path(`str`): path to the save directory
'''
import pydot
dot_object = pydot.Dot(graph_name="network_graph",rankdir='LR')
dot_object.set_node_defaults(shape='circle', fixedsize='false',
height=.85, width=.85, fontsize=10)
accum = []
dfs_walk_plot(model, lambda x: True, accum, set(), dot_object)
dot_object.write_png(path + '\\network_graph.png', prog='dot')
def visit(node, visitor):
'''
Generic function that walks through the graph starting at `node` and