This commit is contained in:
Родитель
915455c730
Коммит
1ce9ddefdf
|
@ -3,33 +3,59 @@
|
||||||
# for full license information.
|
# for full license information.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
def dfs_walk(node, visitor, accum):
|
def dfs_walk(node, visitor, accum, visited):
|
||||||
|
'''
|
||||||
|
Generic function to walk the graph.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node (graph node): the node to start the journey from
|
||||||
|
visitor (Python function or lambda): function that takes a node as
|
||||||
|
argument and returns `True` if that node should be returned.
|
||||||
|
accum (`list`): accumulator of nodes while traversing the graph
|
||||||
|
visited (`set`): set of nodes that have already been visited.
|
||||||
|
Initialize with empty set.
|
||||||
|
'''
|
||||||
|
if node in visited:
|
||||||
|
return
|
||||||
|
visited.add(node)
|
||||||
if hasattr(node, 'root_function'):
|
if hasattr(node, 'root_function'):
|
||||||
node = node.root_function
|
node = node.root_function
|
||||||
for child in node.inputs:
|
for child in node.inputs:
|
||||||
dfs_walk(child, visitor, accum)
|
dfs_walk(child, visitor, accum, visited)
|
||||||
elif hasattr(node, 'is_output') and node.is_output:
|
elif hasattr(node, 'is_output') and node.is_output:
|
||||||
dfs_walk(node.owner, visitor, accum)
|
dfs_walk(node.owner, visitor, accum, visited)
|
||||||
|
|
||||||
if visitor(node):
|
if visitor(node):
|
||||||
accum.append(node)
|
accum.append(node)
|
||||||
|
|
||||||
def visit(root_node, visitor):
|
def visit(node, visitor):
|
||||||
nodes = []
|
|
||||||
dfs_walk(root_node, visitor, nodes)
|
|
||||||
return nodes
|
|
||||||
|
|
||||||
def find_nodes_by_name(root_node, node_name):
|
|
||||||
'''
|
'''
|
||||||
Return a list of nodes having the given name
|
Generic function that walks through the graph starting at `node` and
|
||||||
|
applies function `visitor` on each of those.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
root_node (node in the graph): root node from where the search should
|
node (graph node): the node to start the journey from
|
||||||
start
|
visitor (Python function or lambda): function that takes a node as
|
||||||
node_name (`str`): name of the nodes
|
argument and returns `True` if that node should be returned.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
a list of nodes having the given name
|
List of nodes, for which `visitor` was `True`
|
||||||
|
|
||||||
'''
|
'''
|
||||||
return visit(root_node, lambda x: x.name == node_name)
|
nodes = []
|
||||||
|
dfs_walk(node, visitor, nodes, set())
|
||||||
|
return nodes
|
||||||
|
|
||||||
|
def find_nodes_by_name(node, node_name):
|
||||||
|
'''
|
||||||
|
Finds nodes in the graph starting from `node` and doing a depth-first
|
||||||
|
search.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node (graph node): the node to start the journey from
|
||||||
|
node_name (`str`): name for which we are search nodes
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of nodes having the specified name
|
||||||
|
'''
|
||||||
|
return visit(node, lambda x: x.name == node_name)
|
||||||
|
|
||||||
|
|
|
@ -8,7 +8,7 @@ import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from ..initializer import *
|
from ..initializer import *
|
||||||
from .. import parameter, input_variable, momentums_per_sample
|
from .. import parameter
|
||||||
|
|
||||||
|
|
||||||
def _check(init, name):
|
def _check(init, name):
|
||||||
|
|
Загрузка…
Ссылка в новой задаче