This commit is contained in:
Willi Richert 2016-10-18 17:44:56 +02:00
Родитель 915455c730
Коммит 1ce9ddefdf
2 изменённых файлов: 43 добавлений и 17 удалений

Просмотреть файл

@ -3,33 +3,59 @@
# for full license information.
# ==============================================================================
def dfs_walk(node, visitor, accum):
def dfs_walk(node, visitor, accum, visited):
'''
Generic function to walk the graph.
Args:
node (graph node): the node to start the journey from
visitor (Python function or lambda): function that takes a node as
argument and returns `True` if that node should be returned.
accum (`list`): accumulator of nodes while traversing the graph
visited (`set`): set of nodes that have already been visited.
Initialize with empty set.
'''
if node in visited:
return
visited.add(node)
if hasattr(node, 'root_function'):
node = node.root_function
for child in node.inputs:
dfs_walk(child, visitor, accum)
dfs_walk(child, visitor, accum, visited)
elif hasattr(node, 'is_output') and node.is_output:
dfs_walk(node.owner, visitor, accum)
dfs_walk(node.owner, visitor, accum, visited)
if visitor(node):
accum.append(node)
def visit(root_node, visitor):
nodes = []
dfs_walk(root_node, visitor, nodes)
return nodes
def find_nodes_by_name(root_node, node_name):
def visit(node, visitor):
'''
Return a list of nodes having the given name
Generic function that walks through the graph starting at `node` and
applies function `visitor` on each of those.
Args:
root_node (node in the graph): root node from where the search should
start
node_name (`str`): name of the nodes
node (graph node): the node to start the journey from
visitor (Python function or lambda): function that takes a node as
argument and returns `True` if that node should be returned.
Returns:
a list of nodes having the given name
List of nodes, for which `visitor` was `True`
'''
return visit(root_node, lambda x: x.name == node_name)
nodes = []
dfs_walk(node, visitor, nodes, set())
return nodes
def find_nodes_by_name(node, node_name):
'''
Finds nodes in the graph starting from `node` and doing a depth-first
search.
Args:
node (graph node): the node to start the journey from
node_name (`str`): name for which we are search nodes
Returns:
List of nodes having the specified name
'''
return visit(node, lambda x: x.name == node_name)

Просмотреть файл

@ -8,7 +8,7 @@ import numpy as np
import pytest
from ..initializer import *
from .. import parameter, input_variable, momentums_per_sample
from .. import parameter
def _check(init, name):