This commit is contained in:
Willi Richert 2016-10-18 17:44:56 +02:00
Родитель 915455c730
Коммит 1ce9ddefdf
2 изменённых файлов: 43 добавлений и 17 удалений

Просмотреть файл

@ -3,33 +3,59 @@
# for full license information. # for full license information.
# ============================================================================== # ==============================================================================
def dfs_walk(node, visitor, accum): def dfs_walk(node, visitor, accum, visited):
'''
Generic function to walk the graph.
Args:
node (graph node): the node to start the journey from
visitor (Python function or lambda): function that takes a node as
argument and returns `True` if that node should be returned.
accum (`list`): accumulator of nodes while traversing the graph
visited (`set`): set of nodes that have already been visited.
Initialize with empty set.
'''
if node in visited:
return
visited.add(node)
if hasattr(node, 'root_function'): if hasattr(node, 'root_function'):
node = node.root_function node = node.root_function
for child in node.inputs: for child in node.inputs:
dfs_walk(child, visitor, accum) dfs_walk(child, visitor, accum, visited)
elif hasattr(node, 'is_output') and node.is_output: elif hasattr(node, 'is_output') and node.is_output:
dfs_walk(node.owner, visitor, accum) dfs_walk(node.owner, visitor, accum, visited)
if visitor(node): if visitor(node):
accum.append(node) accum.append(node)
def visit(root_node, visitor): def visit(node, visitor):
nodes = []
dfs_walk(root_node, visitor, nodes)
return nodes
def find_nodes_by_name(root_node, node_name):
''' '''
Return a list of nodes having the given name Generic function that walks through the graph starting at `node` and
applies function `visitor` on each of those.
Args: Args:
root_node (node in the graph): root node from where the search should node (graph node): the node to start the journey from
start visitor (Python function or lambda): function that takes a node as
node_name (`str`): name of the nodes argument and returns `True` if that node should be returned.
Returns: Returns:
a list of nodes having the given name List of nodes, for which `visitor` was `True`
''' '''
return visit(root_node, lambda x: x.name == node_name) nodes = []
dfs_walk(node, visitor, nodes, set())
return nodes
def find_nodes_by_name(node, node_name):
'''
Finds nodes in the graph starting from `node` and doing a depth-first
search.
Args:
node (graph node): the node to start the journey from
node_name (`str`): name for which we are search nodes
Returns:
List of nodes having the specified name
'''
return visit(node, lambda x: x.name == node_name)

Просмотреть файл

@ -8,7 +8,7 @@ import numpy as np
import pytest import pytest
from ..initializer import * from ..initializer import *
from .. import parameter, input_variable, momentums_per_sample from .. import parameter
def _check(init, name): def _check(init, name):