Added sanitary check for step function signature with appropriate error messages.

This commit is contained in:
Yuqing Tang 2018-03-13 14:38:57 -07:00
Родитель 44c626a483
Коммит edbda9d7a5
3 изменённых файлов: 97 добавлений и 16 удалений

Просмотреть файл

@ -162,6 +162,34 @@ def _sanitize_function(f):
return f
def _santize_step_function(f):
import types
from cntk.internal.utils import get_python_function_arguments
if isinstance(f, types.FunctionType):
py_args, _ = get_python_function_arguments(f)
try:
cntk_f, cntk_args = Function._to_Function_unchecked(f)
if len(cntk_f.arguments) > len(py_args):
cntk_args = [v.name for v in cntk_f.arguments]
additional_cntk_args = set(cntk_args) - set(py_args)
raise TypeError(('Recurrence Python step function makes use of additional CNTK variables or placeholders: {}. '
'Your step function arguments in Python code are: {}, '
'while the converted CNTK function argument are: {}. '
'This is currently not a supported Python step function definition. '
'Note that the current supported Python step function signature is: '
'step_function(prev_state_1, prev_state_2, ..., prev_state_n, sequence_input_x) -> next_state_1, next_state_2, ..., next_state_n '
'in which no references to any CNTK variables or placeholders are allowed.'
).format(additional_cntk_args, py_args, cntk_args))
f = Function._sanitize_check_Function(cntk_f, cntk_args, f)
except TypeError as e:
if str(e) != 'parameters cannot be created inside a @Function def':
raise
else:
raise TypeError('Parameter cannot be created inside Recurrence Python step function.')
return f
# TODO: allow to say sequential=False, axis=2, length=100, ... something like this
def RecurrenceFrom(step_function, go_backwards=default_override_or(False), return_full_state=False, name=''):
'''
@ -228,7 +256,7 @@ def RecurrenceFrom(step_function, go_backwards=default_override_or(False), retur
go_backwards = get_default_override(RecurrenceFrom, go_backwards=go_backwards)
step_function = _sanitize_function(step_function)
step_function = _santize_step_function(step_function)
# get signature of step function
#*prev_state_args, _ = step_function.signature # Python 3
@ -392,7 +420,7 @@ def Recurrence(step_function, go_backwards=default_override_or(False), initial_s
initial_state = get_default_override(Recurrence, initial_state=initial_state)
initial_state = _get_initial_state_or_default(initial_state)
step_function = _sanitize_function(step_function)
step_function = _santize_step_function(step_function)
# get signature of step function
#*prev_state_args, _ = step_function.signature # Python 3

Просмотреть файл

@ -162,6 +162,38 @@ def test_recurrence():
rt = FF(s, x)
np.testing.assert_array_almost_equal(rt[0], exp, decimal=6, err_msg='Error in RecurrenceFrom(GRU()) forward')
def test_recurrence_step_fun():
import cntk as C
def step_f(prev1, x):
return prev1 * x
rec = Recurrence(step_f)
def step_f(prev1, prev2, x):
return prev1 * prev2 * x, prev1 * x
rec = Recurrence(step_f)
def step_f(prev1, prev2, prev3, x):
return prev1 * prev2 * prev3 * x, prev1 * x, prev2 * x
rec = Recurrence(step_f)
with pytest.raises(ValueError):
def step_f(prev1, prev2, prev3, prev4, x):
return prev1 * prev2 * prev3 * x, prev1 * x, prev2 * x, prev4 * x
rec = Recurrence(step_f)
with pytest.raises(TypeError):
v = C.input_variable((1), name='additional_input_variable')
step_f = lambda prev, x: prev * v * x
rec = Recurrence(step_f)
with pytest.raises(TypeError):
def step_f(prev1, x):
p = C.Parameter((1))
return prev1 * x * p
rec = Recurrence(step_f)
####################################
# recurrence (Fold()) over regular function
####################################

Просмотреть файл

@ -142,7 +142,7 @@ class Function(cntk_py.Function):
_placeholders_under_construction = set()
@staticmethod
def _to_Function(f, make_block=False, op_name=None, name=None):
def _to_Function_unchecked(f, make_block=False, op_name=None, name=None):
'''implements @Function decorator; see :class:`~cntk.layers.functions.Function`'''
f_name = f.__name__ # (only used for debugging and error messages)
@ -242,21 +242,42 @@ class Function(cntk_py.Function):
fun_args = force_order_args(fun_args)
out = invoke(fun_args)
# verify that we got the parameter order right
out_arg_names = [arg.name for arg in out.signature]
assert out_arg_names == arg_names, (out_arg_names, arg_names)
return out, args
if len(out.signature) != len(args):
unfulfilled_args = set(out.signature) - set(args)
if unfulfilled_args:
unfulfilled_arg_names = [arg.name for arg in unfulfilled_args]
raise TypeError("CNTK Function '{}' has {} missing arguments ({}), which is currently not supported".format(f_name, len(unfulfilled_arg_names), ", ".join(unfulfilled_arg_names)))
else:
unused_args = set(args) - set(out.signature)
unused_arg_names = [arg.name for arg in unused_args]
raise TypeError("CNTK Function '{}' has {} unused arguments ({}), which is currently not supported".format(f_name, len(unused_arg_names), ", ".join(unused_arg_names)))
@staticmethod
def _sanitize_check_Function(f_out, f_args, f):
arg_names, annotations = get_python_function_arguments(f)
#verify the argument length first
if len(f_out.signature) != len(f_args):
f_name = f.__name__
unfulfilled_args = set(f_out.signature) - set(f_args)
if unfulfilled_args:
unfulfilled_arg_names = [arg.name for arg in unfulfilled_args]
raise TypeError(
"CNTK Function '{}' has {} missing arguments ({}), which is currently not supported".format(f_name,
len(
unfulfilled_arg_names),
", ".join(
unfulfilled_arg_names)))
else:
unused_args = set(f_args) - set(f_out.signature)
unused_arg_names = [arg.name for arg in unused_args]
raise TypeError(
"CNTK Function '{}' has {} unused arguments ({}), which is currently not supported".format(f_name,
len(
unused_arg_names),
", ".join(
unused_arg_names)))
return out
#then verify that we got the parameter order right
out_arg_names = [arg.name for arg in f_out.signature]
assert out_arg_names == arg_names, (out_arg_names, arg_names)
return f_out
@staticmethod
def _to_Function(f, make_block=False, op_name=None, name=None):
out, args = Function._to_Function_unchecked(f, make_block, op_name, name)
return Function._sanitize_check_Function(out, args, f)
@property
def signature(self):