Added sanitary check for step function signature with appropriate error messages.
This commit is contained in:
Родитель
44c626a483
Коммит
edbda9d7a5
|
@ -162,6 +162,34 @@ def _sanitize_function(f):
|
|||
return f
|
||||
|
||||
|
||||
def _santize_step_function(f):
|
||||
import types
|
||||
from cntk.internal.utils import get_python_function_arguments
|
||||
if isinstance(f, types.FunctionType):
|
||||
py_args, _ = get_python_function_arguments(f)
|
||||
try:
|
||||
cntk_f, cntk_args = Function._to_Function_unchecked(f)
|
||||
if len(cntk_f.arguments) > len(py_args):
|
||||
cntk_args = [v.name for v in cntk_f.arguments]
|
||||
additional_cntk_args = set(cntk_args) - set(py_args)
|
||||
raise TypeError(('Recurrence Python step function makes use of additional CNTK variables or placeholders: {}. '
|
||||
'Your step function arguments in Python code are: {}, '
|
||||
'while the converted CNTK function argument are: {}. '
|
||||
'This is currently not a supported Python step function definition. '
|
||||
'Note that the current supported Python step function signature is: '
|
||||
'step_function(prev_state_1, prev_state_2, ..., prev_state_n, sequence_input_x) -> next_state_1, next_state_2, ..., next_state_n '
|
||||
'in which no references to any CNTK variables or placeholders are allowed.'
|
||||
).format(additional_cntk_args, py_args, cntk_args))
|
||||
f = Function._sanitize_check_Function(cntk_f, cntk_args, f)
|
||||
except TypeError as e:
|
||||
if str(e) != 'parameters cannot be created inside a @Function def':
|
||||
raise
|
||||
else:
|
||||
raise TypeError('Parameter cannot be created inside Recurrence Python step function.')
|
||||
|
||||
return f
|
||||
|
||||
|
||||
# TODO: allow to say sequential=False, axis=2, length=100, ... something like this
|
||||
def RecurrenceFrom(step_function, go_backwards=default_override_or(False), return_full_state=False, name=''):
|
||||
'''
|
||||
|
@ -228,7 +256,7 @@ def RecurrenceFrom(step_function, go_backwards=default_override_or(False), retur
|
|||
|
||||
go_backwards = get_default_override(RecurrenceFrom, go_backwards=go_backwards)
|
||||
|
||||
step_function = _sanitize_function(step_function)
|
||||
step_function = _santize_step_function(step_function)
|
||||
|
||||
# get signature of step function
|
||||
#*prev_state_args, _ = step_function.signature # Python 3
|
||||
|
@ -392,7 +420,7 @@ def Recurrence(step_function, go_backwards=default_override_or(False), initial_s
|
|||
initial_state = get_default_override(Recurrence, initial_state=initial_state)
|
||||
initial_state = _get_initial_state_or_default(initial_state)
|
||||
|
||||
step_function = _sanitize_function(step_function)
|
||||
step_function = _santize_step_function(step_function)
|
||||
|
||||
# get signature of step function
|
||||
#*prev_state_args, _ = step_function.signature # Python 3
|
||||
|
|
|
@ -162,6 +162,38 @@ def test_recurrence():
|
|||
rt = FF(s, x)
|
||||
np.testing.assert_array_almost_equal(rt[0], exp, decimal=6, err_msg='Error in RecurrenceFrom(GRU()) forward')
|
||||
|
||||
def test_recurrence_step_fun():
|
||||
import cntk as C
|
||||
def step_f(prev1, x):
|
||||
return prev1 * x
|
||||
rec = Recurrence(step_f)
|
||||
|
||||
def step_f(prev1, prev2, x):
|
||||
return prev1 * prev2 * x, prev1 * x
|
||||
rec = Recurrence(step_f)
|
||||
|
||||
def step_f(prev1, prev2, prev3, x):
|
||||
return prev1 * prev2 * prev3 * x, prev1 * x, prev2 * x
|
||||
rec = Recurrence(step_f)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
def step_f(prev1, prev2, prev3, prev4, x):
|
||||
return prev1 * prev2 * prev3 * x, prev1 * x, prev2 * x, prev4 * x
|
||||
rec = Recurrence(step_f)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
v = C.input_variable((1), name='additional_input_variable')
|
||||
step_f = lambda prev, x: prev * v * x
|
||||
rec = Recurrence(step_f)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
def step_f(prev1, x):
|
||||
p = C.Parameter((1))
|
||||
return prev1 * x * p
|
||||
rec = Recurrence(step_f)
|
||||
|
||||
|
||||
|
||||
####################################
|
||||
# recurrence (Fold()) over regular function
|
||||
####################################
|
||||
|
|
|
@ -142,7 +142,7 @@ class Function(cntk_py.Function):
|
|||
_placeholders_under_construction = set()
|
||||
|
||||
@staticmethod
|
||||
def _to_Function(f, make_block=False, op_name=None, name=None):
|
||||
def _to_Function_unchecked(f, make_block=False, op_name=None, name=None):
|
||||
'''implements @Function decorator; see :class:`~cntk.layers.functions.Function`'''
|
||||
f_name = f.__name__ # (only used for debugging and error messages)
|
||||
|
||||
|
@ -242,21 +242,42 @@ class Function(cntk_py.Function):
|
|||
fun_args = force_order_args(fun_args)
|
||||
out = invoke(fun_args)
|
||||
|
||||
# verify that we got the parameter order right
|
||||
out_arg_names = [arg.name for arg in out.signature]
|
||||
assert out_arg_names == arg_names, (out_arg_names, arg_names)
|
||||
return out, args
|
||||
|
||||
if len(out.signature) != len(args):
|
||||
unfulfilled_args = set(out.signature) - set(args)
|
||||
if unfulfilled_args:
|
||||
unfulfilled_arg_names = [arg.name for arg in unfulfilled_args]
|
||||
raise TypeError("CNTK Function '{}' has {} missing arguments ({}), which is currently not supported".format(f_name, len(unfulfilled_arg_names), ", ".join(unfulfilled_arg_names)))
|
||||
else:
|
||||
unused_args = set(args) - set(out.signature)
|
||||
unused_arg_names = [arg.name for arg in unused_args]
|
||||
raise TypeError("CNTK Function '{}' has {} unused arguments ({}), which is currently not supported".format(f_name, len(unused_arg_names), ", ".join(unused_arg_names)))
|
||||
@staticmethod
|
||||
def _sanitize_check_Function(f_out, f_args, f):
|
||||
arg_names, annotations = get_python_function_arguments(f)
|
||||
#verify the argument length first
|
||||
if len(f_out.signature) != len(f_args):
|
||||
f_name = f.__name__
|
||||
unfulfilled_args = set(f_out.signature) - set(f_args)
|
||||
if unfulfilled_args:
|
||||
unfulfilled_arg_names = [arg.name for arg in unfulfilled_args]
|
||||
raise TypeError(
|
||||
"CNTK Function '{}' has {} missing arguments ({}), which is currently not supported".format(f_name,
|
||||
len(
|
||||
unfulfilled_arg_names),
|
||||
", ".join(
|
||||
unfulfilled_arg_names)))
|
||||
else:
|
||||
unused_args = set(f_args) - set(f_out.signature)
|
||||
unused_arg_names = [arg.name for arg in unused_args]
|
||||
raise TypeError(
|
||||
"CNTK Function '{}' has {} unused arguments ({}), which is currently not supported".format(f_name,
|
||||
len(
|
||||
unused_arg_names),
|
||||
", ".join(
|
||||
unused_arg_names)))
|
||||
|
||||
return out
|
||||
#then verify that we got the parameter order right
|
||||
out_arg_names = [arg.name for arg in f_out.signature]
|
||||
assert out_arg_names == arg_names, (out_arg_names, arg_names)
|
||||
return f_out
|
||||
|
||||
@staticmethod
|
||||
def _to_Function(f, make_block=False, op_name=None, name=None):
|
||||
out, args = Function._to_Function_unchecked(f, make_block, op_name, name)
|
||||
return Function._sanitize_check_Function(out, args, f)
|
||||
|
||||
@property
|
||||
def signature(self):
|
||||
|
|
Загрузка…
Ссылка в новой задаче