diff --git a/bindings/python/cntk/layers/sequence.py b/bindings/python/cntk/layers/sequence.py index 3102b2e7d..954123fa7 100644 --- a/bindings/python/cntk/layers/sequence.py +++ b/bindings/python/cntk/layers/sequence.py @@ -162,6 +162,34 @@ def _sanitize_function(f): return f +def _santize_step_function(f): + import types + from cntk.internal.utils import get_python_function_arguments + if isinstance(f, types.FunctionType): + py_args, _ = get_python_function_arguments(f) + try: + cntk_f, cntk_args = Function._to_Function_unchecked(f) + if len(cntk_f.arguments) > len(py_args): + cntk_args = [v.name for v in cntk_f.arguments] + additional_cntk_args = set(cntk_args) - set(py_args) + raise TypeError(('Recurrence Python step function makes use of additional CNTK variables or placeholders: {}. ' + 'Your step function arguments in Python code are: {}, ' + 'while the converted CNTK function argument are: {}. ' + 'This is currently not a supported Python step function definition. ' + 'Note that the current supported Python step function signature is: ' + 'step_function(prev_state_1, prev_state_2, ..., prev_state_n, sequence_input_x) -> next_state_1, next_state_2, ..., next_state_n ' + 'in which no references to any CNTK variables or placeholders are allowed.' + ).format(additional_cntk_args, py_args, cntk_args)) + f = Function._sanitize_check_Function(cntk_f, cntk_args, f) + except TypeError as e: + if str(e) != 'parameters cannot be created inside a @Function def': + raise + else: + raise TypeError('Parameter cannot be created inside Recurrence Python step function.') + + return f + + # TODO: allow to say sequential=False, axis=2, length=100, ... something like this def RecurrenceFrom(step_function, go_backwards=default_override_or(False), return_full_state=False, name=''): ''' @@ -228,7 +256,7 @@ def RecurrenceFrom(step_function, go_backwards=default_override_or(False), retur go_backwards = get_default_override(RecurrenceFrom, go_backwards=go_backwards) - step_function = _sanitize_function(step_function) + step_function = _santize_step_function(step_function) # get signature of step function #*prev_state_args, _ = step_function.signature # Python 3 @@ -392,7 +420,7 @@ def Recurrence(step_function, go_backwards=default_override_or(False), initial_s initial_state = get_default_override(Recurrence, initial_state=initial_state) initial_state = _get_initial_state_or_default(initial_state) - step_function = _sanitize_function(step_function) + step_function = _santize_step_function(step_function) # get signature of step function #*prev_state_args, _ = step_function.signature # Python 3 diff --git a/bindings/python/cntk/layers/tests/layers_test.py b/bindings/python/cntk/layers/tests/layers_test.py index 7875e7b4a..758bc4be8 100644 --- a/bindings/python/cntk/layers/tests/layers_test.py +++ b/bindings/python/cntk/layers/tests/layers_test.py @@ -162,6 +162,38 @@ def test_recurrence(): rt = FF(s, x) np.testing.assert_array_almost_equal(rt[0], exp, decimal=6, err_msg='Error in RecurrenceFrom(GRU()) forward') +def test_recurrence_step_fun(): + import cntk as C + def step_f(prev1, x): + return prev1 * x + rec = Recurrence(step_f) + + def step_f(prev1, prev2, x): + return prev1 * prev2 * x, prev1 * x + rec = Recurrence(step_f) + + def step_f(prev1, prev2, prev3, x): + return prev1 * prev2 * prev3 * x, prev1 * x, prev2 * x + rec = Recurrence(step_f) + + with pytest.raises(ValueError): + def step_f(prev1, prev2, prev3, prev4, x): + return prev1 * prev2 * prev3 * x, prev1 * x, prev2 * x, prev4 * x + rec = Recurrence(step_f) + + with pytest.raises(TypeError): + v = C.input_variable((1), name='additional_input_variable') + step_f = lambda prev, x: prev * v * x + rec = Recurrence(step_f) + + with pytest.raises(TypeError): + def step_f(prev1, x): + p = C.Parameter((1)) + return prev1 * x * p + rec = Recurrence(step_f) + + + #################################### # recurrence (Fold()) over regular function #################################### diff --git a/bindings/python/cntk/ops/functions.py b/bindings/python/cntk/ops/functions.py index 08fbfa1fe..b559114d8 100644 --- a/bindings/python/cntk/ops/functions.py +++ b/bindings/python/cntk/ops/functions.py @@ -142,7 +142,7 @@ class Function(cntk_py.Function): _placeholders_under_construction = set() @staticmethod - def _to_Function(f, make_block=False, op_name=None, name=None): + def _to_Function_unchecked(f, make_block=False, op_name=None, name=None): '''implements @Function decorator; see :class:`~cntk.layers.functions.Function`''' f_name = f.__name__ # (only used for debugging and error messages) @@ -242,21 +242,42 @@ class Function(cntk_py.Function): fun_args = force_order_args(fun_args) out = invoke(fun_args) - # verify that we got the parameter order right - out_arg_names = [arg.name for arg in out.signature] - assert out_arg_names == arg_names, (out_arg_names, arg_names) + return out, args - if len(out.signature) != len(args): - unfulfilled_args = set(out.signature) - set(args) - if unfulfilled_args: - unfulfilled_arg_names = [arg.name for arg in unfulfilled_args] - raise TypeError("CNTK Function '{}' has {} missing arguments ({}), which is currently not supported".format(f_name, len(unfulfilled_arg_names), ", ".join(unfulfilled_arg_names))) - else: - unused_args = set(args) - set(out.signature) - unused_arg_names = [arg.name for arg in unused_args] - raise TypeError("CNTK Function '{}' has {} unused arguments ({}), which is currently not supported".format(f_name, len(unused_arg_names), ", ".join(unused_arg_names))) + @staticmethod + def _sanitize_check_Function(f_out, f_args, f): + arg_names, annotations = get_python_function_arguments(f) + #verify the argument length first + if len(f_out.signature) != len(f_args): + f_name = f.__name__ + unfulfilled_args = set(f_out.signature) - set(f_args) + if unfulfilled_args: + unfulfilled_arg_names = [arg.name for arg in unfulfilled_args] + raise TypeError( + "CNTK Function '{}' has {} missing arguments ({}), which is currently not supported".format(f_name, + len( + unfulfilled_arg_names), + ", ".join( + unfulfilled_arg_names))) + else: + unused_args = set(f_args) - set(f_out.signature) + unused_arg_names = [arg.name for arg in unused_args] + raise TypeError( + "CNTK Function '{}' has {} unused arguments ({}), which is currently not supported".format(f_name, + len( + unused_arg_names), + ", ".join( + unused_arg_names))) - return out + #then verify that we got the parameter order right + out_arg_names = [arg.name for arg in f_out.signature] + assert out_arg_names == arg_names, (out_arg_names, arg_names) + return f_out + + @staticmethod + def _to_Function(f, make_block=False, op_name=None, name=None): + out, args = Function._to_Function_unchecked(f, make_block, op_name, name) + return Function._sanitize_check_Function(out, args, f) @property def signature(self):