Restricting loops to Delay, PastValue, and FutureValue

This commit is contained in:
Willi Richert 2016-03-30 16:59:49 +02:00
Родитель 7e03dfb8cc
Коммит 21cd2376d1
1 изменённых файлов: 14 добавлений и 3 удалений

Просмотреть файл

@ -140,6 +140,17 @@ class ComputationNode(object):
return param
def _is_forward_ref(self, p_name, p_value):
'''
Although the unrolled graph is a DAG, when we specify recurrence we
naturally have loops. We can resolve this by using forward references.
This method is checking whether the particular name and value of this
instance are actually one of those forward references.
'''
is_loop_node = self.name in ('Delay', 'PastValue', 'FutureValue')
return is_loop_node and p_name == 'input' and isinstance(p_value, str)
def _to_config_recursively(self, desc, unrolled_nodes, inputs,
readers, node_counter=0):
param_variable_names = []
@ -161,8 +172,8 @@ class ComputationNode(object):
input_nodes_vars = []
for pv in inputs_param:
if pv in unrolled_nodes:
# we have seen this node already, so just retrieve its
# name
# We have seen this node already, so just retrieve its
# name.
child_var = unrolled_nodes[pv]
else:
child_var, node_counter, child_desc = pv._to_config_recursively(
@ -173,7 +184,7 @@ class ComputationNode(object):
param_variable_names.append(
_tuple_to_cntk_shape(input_nodes_vars))
else:
if p_name == 'input' and isinstance(p_value, str):
if self._is_forward_ref(p_name, p_value):
# We have a forward reference to a node that will be
# later on defined. p_value is the var_name of the
# later defined node.