This commit is contained in:
jeanfad 2016-04-05 14:08:34 +02:00
Родитель f2dd02b029
Коммит b8cd9a1279
2 изменённых файлов: 23 добавлений и 12 удалений

Просмотреть файл

@ -285,14 +285,14 @@ class AbstractContext(object, metaclass=ABCMeta):
pass pass
@abstractmethod @abstractmethod
def eval(self, node, reader=None, node_unit_test=False): def eval(self, node, reader=None, backward_pass=False):
''' '''
Abstract method for the action write. It evaluated the passed node on the Abstract method for the action write. It evaluated the passed node on the
data provided by the reader. data provided by the reader.
:param node: the node to evaluate. :param node: the node to evaluate.
:param reader: the reader to use for this action. Alternatively, you :param reader: the reader to use for this action. Alternatively, you
can attach a reader directly to the input node. can attach a reader directly to the input node.
:param node_unit_test: set to True if you want to output the gradient of a node (backward pass) :param backward_pass: set to True if you want to output the gradient of a node (backward pass)
Returns the output generated by `node` Returns the output generated by `node`
''' '''
pass pass
@ -543,7 +543,7 @@ class Context(AbstractContext):
return expected_shape, expected_size return expected_shape, expected_size
def eval(self, node, reader=None, node_unit_test=False, input_name=None): def eval(self, node, reader=None, backward_pass=False, input_name=None):
''' '''
Run the write action locally to evaluate the passed node and returning Run the write action locally to evaluate the passed node and returning
the data it produced. the data it produced.
@ -551,8 +551,8 @@ class Context(AbstractContext):
:param node: the node to evaluate. :param node: the node to evaluate.
:param reader: the reader to use for this action. Alternatively, you :param reader: the reader to use for this action. Alternatively, you
can attach a reader directly to the input node. can attach a reader directly to the input node.
:param node_unit_test: set to True if you want to output the gradient of a node (backward pass) :param backward_pass: set to True if you want to output the gradient of a node (backward pass)
:input_name: if node_unit_test is True then input_node should contain the input name that :input_name: if backward_pass is True then input_node should contain the input name that
the gradient is performed with respect to. the gradient is performed with respect to.
Returns the output generated by `node` Returns the output generated by `node`
''' '''
@ -564,15 +564,15 @@ class Context(AbstractContext):
orig_node_tag = node.tag if hasattr(node, 'tag') else None orig_node_tag = node.tag if hasattr(node, 'tag') else None
node.tag = 'output' node.tag = 'output'
config_content = self._generate_eval_config(node, reader, node_unit_test) config_content = self._generate_eval_config(node, reader, backward_pass)
output = self._call_cntk(CNTK_EVAL_CONFIG_FILENAME, config_content) self._call_cntk(CNTK_EVAL_CONFIG_FILENAME, config_content)
node.tag = orig_node_tag node.tag = orig_node_tag
n = input_name.var_name if isinstance(input_name, ComputationNode) else input_name n = input_name.var_name if isinstance(input_name, ComputationNode) else input_name
out_name = os.path.join( out_name = os.path.join(
self.directory, CNTK_OUTPUT_FILENAME + '.' + \ self.directory, CNTK_OUTPUT_FILENAME + '.' + \
((n + '.grad') if node_unit_test else node.var_name)) ((n + '.grad') if backward_pass else node.var_name))
result_content = open(out_name).read() result_content = open(out_name).read()
data = Context._parse_result_output(result_content) data = Context._parse_result_output(result_content)

Просмотреть файл

@ -11,14 +11,17 @@ C = constant
I = input I = input
AA = np.asarray AA = np.asarray
def _test(root_node, expected, clean_up=True): def _test(root_node, expected, clean_up=True, backward_pass = False, input_node = None):
with get_new_context() as ctx: with get_new_context() as ctx:
ctx.clean_up = clean_up ctx.clean_up = clean_up
assert not ctx.input_nodes assert not ctx.input_nodes
result = ctx.eval(root_node) result = ctx.eval(root_node, None, backward_pass, input_node)
assert len(result) == len(expected) assert len(result) == len(expected)
for res, exp in zip(result, expected): for res, exp in zip(result, expected):
print(res)
print(exp)
print("===========#####")
assert np.allclose(res, exp) assert np.allclose(res, exp)
assert res.shape == AA(exp).shape assert res.shape == AA(exp).shape
@ -84,9 +87,17 @@ def test_op_add_input_constant(left_arg, right_arg):
], ],
2), 2),
]) ])
def test_op_mul_input_seq(left_arg, right_arg): def test_op_mul_input_seq(left_arg, right_arg):
expected = [AA(elem)*right_arg for elem in left_arg] expected = [AA(elem)*right_arg for elem in left_arg]
result = I(left_arg, has_sequence_dimension=True) * right_arg result = I(left_arg, has_sequence_dimension=True) * right_arg
_test(result, expected, False) _test(result, expected, False)
@pytest.mark.parametrize("left_arg, right_arg", [
([0],[0]), # grad(Cos(0)) = -Sin(0) = 0
([1.57079633],[-1]), # grad(Cos(pi/2)) = -Sin(pi/2) = -1
([3.14159265],[0]), # grad(Cos(pi)) = -Sin(pi) = 0
])
def test_cosine_backward(left_arg, right_arg):
i = I([left_arg], has_sequence_dimension=False)
n = Cosine(i)
_test(n, [[AA(right_arg)]], clean_up=False, backward_pass = True, input_node = i)