This commit is contained in:
Родитель
915455c730
Коммит
1ce9ddefdf
|
@ -3,33 +3,59 @@
|
|||
# for full license information.
|
||||
# ==============================================================================
|
||||
|
||||
def dfs_walk(node, visitor, accum):
|
||||
def dfs_walk(node, visitor, accum, visited):
|
||||
'''
|
||||
Generic function to walk the graph.
|
||||
|
||||
Args:
|
||||
node (graph node): the node to start the journey from
|
||||
visitor (Python function or lambda): function that takes a node as
|
||||
argument and returns `True` if that node should be returned.
|
||||
accum (`list`): accumulator of nodes while traversing the graph
|
||||
visited (`set`): set of nodes that have already been visited.
|
||||
Initialize with empty set.
|
||||
'''
|
||||
if node in visited:
|
||||
return
|
||||
visited.add(node)
|
||||
if hasattr(node, 'root_function'):
|
||||
node = node.root_function
|
||||
for child in node.inputs:
|
||||
dfs_walk(child, visitor, accum)
|
||||
dfs_walk(child, visitor, accum, visited)
|
||||
elif hasattr(node, 'is_output') and node.is_output:
|
||||
dfs_walk(node.owner, visitor, accum)
|
||||
dfs_walk(node.owner, visitor, accum, visited)
|
||||
|
||||
if visitor(node):
|
||||
accum.append(node)
|
||||
|
||||
def visit(root_node, visitor):
|
||||
nodes = []
|
||||
dfs_walk(root_node, visitor, nodes)
|
||||
return nodes
|
||||
|
||||
def find_nodes_by_name(root_node, node_name):
|
||||
def visit(node, visitor):
|
||||
'''
|
||||
Return a list of nodes having the given name
|
||||
Generic function that walks through the graph starting at `node` and
|
||||
applies function `visitor` on each of those.
|
||||
|
||||
Args:
|
||||
root_node (node in the graph): root node from where the search should
|
||||
start
|
||||
node_name (`str`): name of the nodes
|
||||
node (graph node): the node to start the journey from
|
||||
visitor (Python function or lambda): function that takes a node as
|
||||
argument and returns `True` if that node should be returned.
|
||||
|
||||
Returns:
|
||||
a list of nodes having the given name
|
||||
|
||||
List of nodes, for which `visitor` was `True`
|
||||
'''
|
||||
return visit(root_node, lambda x: x.name == node_name)
|
||||
nodes = []
|
||||
dfs_walk(node, visitor, nodes, set())
|
||||
return nodes
|
||||
|
||||
def find_nodes_by_name(node, node_name):
|
||||
'''
|
||||
Finds nodes in the graph starting from `node` and doing a depth-first
|
||||
search.
|
||||
|
||||
Args:
|
||||
node (graph node): the node to start the journey from
|
||||
node_name (`str`): name for which we are search nodes
|
||||
|
||||
Returns:
|
||||
List of nodes having the specified name
|
||||
'''
|
||||
return visit(node, lambda x: x.name == node_name)
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ import numpy as np
|
|||
import pytest
|
||||
|
||||
from ..initializer import *
|
||||
from .. import parameter, input_variable, momentums_per_sample
|
||||
from .. import parameter
|
||||
|
||||
|
||||
def _check(init, name):
|
||||
|
|
Загрузка…
Ссылка в новой задаче