This commit is contained in:
jeanfad 2016-04-05 14:08:34 +02:00
Родитель f2dd02b029
Коммит b8cd9a1279
2 изменённых файлов: 23 добавлений и 12 удалений

Просмотреть файл

@ -285,14 +285,14 @@ class AbstractContext(object, metaclass=ABCMeta):
pass
@abstractmethod
def eval(self, node, reader=None, node_unit_test=False):
def eval(self, node, reader=None, backward_pass=False):
'''
Abstract method for the action write. It evaluated the passed node on the
data provided by the reader.
:param node: the node to evaluate.
:param reader: the reader to use for this action. Alternatively, you
can attach a reader directly to the input node.
:param node_unit_test: set to True if you want to output the gradient of a node (backward pass)
:param backward_pass: set to True if you want to output the gradient of a node (backward pass)
Returns the output generated by `node`
'''
pass
@ -543,7 +543,7 @@ class Context(AbstractContext):
return expected_shape, expected_size
def eval(self, node, reader=None, node_unit_test=False, input_name=None):
def eval(self, node, reader=None, backward_pass=False, input_name=None):
'''
Run the write action locally to evaluate the passed node and returning
the data it produced.
@ -551,8 +551,8 @@ class Context(AbstractContext):
:param node: the node to evaluate.
:param reader: the reader to use for this action. Alternatively, you
can attach a reader directly to the input node.
:param node_unit_test: set to True if you want to output the gradient of a node (backward pass)
:input_name: if node_unit_test is True then input_node should contain the input name that
:param backward_pass: set to True if you want to output the gradient of a node (backward pass)
:input_name: if backward_pass is True then input_node should contain the input name that
the gradient is performed with respect to.
Returns the output generated by `node`
'''
@ -564,15 +564,15 @@ class Context(AbstractContext):
orig_node_tag = node.tag if hasattr(node, 'tag') else None
node.tag = 'output'
config_content = self._generate_eval_config(node, reader, node_unit_test)
output = self._call_cntk(CNTK_EVAL_CONFIG_FILENAME, config_content)
config_content = self._generate_eval_config(node, reader, backward_pass)
self._call_cntk(CNTK_EVAL_CONFIG_FILENAME, config_content)
node.tag = orig_node_tag
n = input_name.var_name if isinstance(input_name, ComputationNode) else input_name
out_name = os.path.join(
self.directory, CNTK_OUTPUT_FILENAME + '.' + \
((n + '.grad') if node_unit_test else node.var_name))
((n + '.grad') if backward_pass else node.var_name))
result_content = open(out_name).read()
data = Context._parse_result_output(result_content)

Просмотреть файл

@ -11,14 +11,17 @@ C = constant
I = input
AA = np.asarray
def _test(root_node, expected, clean_up=True):
def _test(root_node, expected, clean_up=True, backward_pass = False, input_node = None):
with get_new_context() as ctx:
ctx.clean_up = clean_up
assert not ctx.input_nodes
result = ctx.eval(root_node)
result = ctx.eval(root_node, None, backward_pass, input_node)
assert len(result) == len(expected)
for res, exp in zip(result, expected):
for res, exp in zip(result, expected):
print(res)
print(exp)
print("===========#####")
assert np.allclose(res, exp)
assert res.shape == AA(exp).shape
@ -84,9 +87,17 @@ def test_op_add_input_constant(left_arg, right_arg):
],
2),
])
def test_op_mul_input_seq(left_arg, right_arg):
expected = [AA(elem)*right_arg for elem in left_arg]
result = I(left_arg, has_sequence_dimension=True) * right_arg
_test(result, expected, False)
@pytest.mark.parametrize("left_arg, right_arg", [
([0],[0]), # grad(Cos(0)) = -Sin(0) = 0
([1.57079633],[-1]), # grad(Cos(pi/2)) = -Sin(pi/2) = -1
([3.14159265],[0]), # grad(Cos(pi)) = -Sin(pi) = 0
])
def test_cosine_backward(left_arg, right_arg):
i = I([left_arg], has_sequence_dimension=False)
n = Cosine(i)
_test(n, [[AA(right_arg)]], clean_up=False, backward_pass = True, input_node = i)