added basic network plotting
This commit is contained in:
Родитель
e55cd0afa8
Коммит
a9ac4e090a
|
@ -29,6 +29,66 @@ def dfs_walk(node, visitor, accum, visited):
|
||||||
if visitor(node):
|
if visitor(node):
|
||||||
accum.append(node)
|
accum.append(node)
|
||||||
|
|
||||||
|
|
||||||
|
def build_graph(node, visitor, accum, visited, dot_object):
|
||||||
|
'''
|
||||||
|
Generic function to build the graph.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node (graph node): the node to start the journey from
|
||||||
|
visitor (Python function or lambda): function that takes a node as
|
||||||
|
argument and returns `True` if that node should be returned.
|
||||||
|
accum (`list`): accumulator of nodes while traversing the graph
|
||||||
|
visited (`set`): set of nodes that have already been visited.
|
||||||
|
Initialize with empty set.
|
||||||
|
dot_object(`Pydot.Dot`): contains the graph description in
|
||||||
|
dot format
|
||||||
|
'''
|
||||||
|
import pydot
|
||||||
|
if node in visited:
|
||||||
|
return
|
||||||
|
visited.add(node)
|
||||||
|
|
||||||
|
if hasattr(node, 'root_function'):
|
||||||
|
node = node.root_function
|
||||||
|
cur_node = pydot.Node(node.op_name+' '+node.uid, label=node.op_name)
|
||||||
|
dot_object.add_node(cur_node)
|
||||||
|
out_node = pydot.Node(node.outputs[0].uid, label=node.outputs[0].uid
|
||||||
|
+ '\nshape:\n' + str(node.outputs[0].shape))
|
||||||
|
dot_object.add_node(out_node)
|
||||||
|
dot_object.add_edge(pydot.Edge(cur_node,out_node))
|
||||||
|
for child in node.inputs:
|
||||||
|
child_node = pydot.Node(child.uid, label=child.uid + '\nshape:\n' + str(child.shape))
|
||||||
|
dot_object.add_node(child_node)
|
||||||
|
dot_object.add_edge(pydot.Edge(child_node, cur_node))
|
||||||
|
dfs_walk_plot(child, visitor, accum, visited, dot_object)
|
||||||
|
|
||||||
|
elif hasattr(node, 'is_output') and node.is_output:
|
||||||
|
dfs_walk_plot(node.owner, visitor, accum, visited, dot_object)
|
||||||
|
|
||||||
|
if visitor(node):
|
||||||
|
accum.append(node)
|
||||||
|
|
||||||
|
|
||||||
|
def png_graph(model, path):
|
||||||
|
'''
|
||||||
|
Saves the network graph to the file
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model(`cntk.ops.functions.Function`): model to plot
|
||||||
|
path(`str`): path to the save directory
|
||||||
|
'''
|
||||||
|
import pydot
|
||||||
|
dot_object = pydot.Dot(graph_name="network_graph",rankdir='LR')
|
||||||
|
dot_object.set_node_defaults(shape='circle', fixedsize='false',
|
||||||
|
height=.85, width=.85, fontsize=10)
|
||||||
|
|
||||||
|
accum = []
|
||||||
|
dfs_walk_plot(model, lambda x: True, accum, set(), dot_object)
|
||||||
|
dot_object.write_png(path + '\\network_graph.png', prog='dot')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def visit(node, visitor):
|
def visit(node, visitor):
|
||||||
'''
|
'''
|
||||||
Generic function that walks through the graph starting at `node` and
|
Generic function that walks through the graph starting at `node` and
|
||||||
|
|
Загрузка…
Ссылка в новой задаче