From 21cd2376d1b63a378e46a0098a85dd2482132dd7 Mon Sep 17 00:00:00 2001 From: Willi Richert Date: Wed, 30 Mar 2016 16:59:49 +0200 Subject: [PATCH] Restricting loops to Delay, PastValue, and FutureValue --- contrib/Python/cntk/graph.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/contrib/Python/cntk/graph.py b/contrib/Python/cntk/graph.py index 53b5d6875..367a13fd8 100644 --- a/contrib/Python/cntk/graph.py +++ b/contrib/Python/cntk/graph.py @@ -140,6 +140,17 @@ class ComputationNode(object): return param + + def _is_forward_ref(self, p_name, p_value): + ''' + Although the unrolled graph is a DAG, when we specify recurrence we + naturally have loops. We can resolve this by using forward references. + This method is checking whether the particular name and value of this + instance are actually one of those forward references. + ''' + is_loop_node = self.name in ('Delay', 'PastValue', 'FutureValue') + return is_loop_node and p_name == 'input' and isinstance(p_value, str) + def _to_config_recursively(self, desc, unrolled_nodes, inputs, readers, node_counter=0): param_variable_names = [] @@ -161,8 +172,8 @@ class ComputationNode(object): input_nodes_vars = [] for pv in inputs_param: if pv in unrolled_nodes: - # we have seen this node already, so just retrieve its - # name + # We have seen this node already, so just retrieve its + # name. child_var = unrolled_nodes[pv] else: child_var, node_counter, child_desc = pv._to_config_recursively( @@ -173,7 +184,7 @@ class ComputationNode(object): param_variable_names.append( _tuple_to_cntk_shape(input_nodes_vars)) else: - if p_name == 'input' and isinstance(p_value, str): + if self._is_forward_ref(p_name, p_value): # We have a forward reference to a node that will be # later on defined. p_value is the var_name of the # later defined node.