Restricting loops to Delay, PastValue, and FutureValue
This commit is contained in:
Родитель
7e03dfb8cc
Коммит
21cd2376d1
|
@ -140,6 +140,17 @@ class ComputationNode(object):
|
|||
|
||||
return param
|
||||
|
||||
|
||||
def _is_forward_ref(self, p_name, p_value):
|
||||
'''
|
||||
Although the unrolled graph is a DAG, when we specify recurrence we
|
||||
naturally have loops. We can resolve this by using forward references.
|
||||
This method is checking whether the particular name and value of this
|
||||
instance are actually one of those forward references.
|
||||
'''
|
||||
is_loop_node = self.name in ('Delay', 'PastValue', 'FutureValue')
|
||||
return is_loop_node and p_name == 'input' and isinstance(p_value, str)
|
||||
|
||||
def _to_config_recursively(self, desc, unrolled_nodes, inputs,
|
||||
readers, node_counter=0):
|
||||
param_variable_names = []
|
||||
|
@ -161,8 +172,8 @@ class ComputationNode(object):
|
|||
input_nodes_vars = []
|
||||
for pv in inputs_param:
|
||||
if pv in unrolled_nodes:
|
||||
# we have seen this node already, so just retrieve its
|
||||
# name
|
||||
# We have seen this node already, so just retrieve its
|
||||
# name.
|
||||
child_var = unrolled_nodes[pv]
|
||||
else:
|
||||
child_var, node_counter, child_desc = pv._to_config_recursively(
|
||||
|
@ -173,7 +184,7 @@ class ComputationNode(object):
|
|||
param_variable_names.append(
|
||||
_tuple_to_cntk_shape(input_nodes_vars))
|
||||
else:
|
||||
if p_name == 'input' and isinstance(p_value, str):
|
||||
if self._is_forward_ref(p_name, p_value):
|
||||
# We have a forward reference to a node that will be
|
||||
# later on defined. p_value is the var_name of the
|
||||
# later defined node.
|
||||
|
|
Загрузка…
Ссылка в новой задаче