first gradient unittest
This commit is contained in:
Родитель
f2dd02b029
Коммит
b8cd9a1279
|
@ -285,14 +285,14 @@ class AbstractContext(object, metaclass=ABCMeta):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def eval(self, node, reader=None, node_unit_test=False):
|
def eval(self, node, reader=None, backward_pass=False):
|
||||||
'''
|
'''
|
||||||
Abstract method for the action write. It evaluated the passed node on the
|
Abstract method for the action write. It evaluated the passed node on the
|
||||||
data provided by the reader.
|
data provided by the reader.
|
||||||
:param node: the node to evaluate.
|
:param node: the node to evaluate.
|
||||||
:param reader: the reader to use for this action. Alternatively, you
|
:param reader: the reader to use for this action. Alternatively, you
|
||||||
can attach a reader directly to the input node.
|
can attach a reader directly to the input node.
|
||||||
:param node_unit_test: set to True if you want to output the gradient of a node (backward pass)
|
:param backward_pass: set to True if you want to output the gradient of a node (backward pass)
|
||||||
Returns the output generated by `node`
|
Returns the output generated by `node`
|
||||||
'''
|
'''
|
||||||
pass
|
pass
|
||||||
|
@ -543,7 +543,7 @@ class Context(AbstractContext):
|
||||||
|
|
||||||
return expected_shape, expected_size
|
return expected_shape, expected_size
|
||||||
|
|
||||||
def eval(self, node, reader=None, node_unit_test=False, input_name=None):
|
def eval(self, node, reader=None, backward_pass=False, input_name=None):
|
||||||
'''
|
'''
|
||||||
Run the write action locally to evaluate the passed node and returning
|
Run the write action locally to evaluate the passed node and returning
|
||||||
the data it produced.
|
the data it produced.
|
||||||
|
@ -551,8 +551,8 @@ class Context(AbstractContext):
|
||||||
:param node: the node to evaluate.
|
:param node: the node to evaluate.
|
||||||
:param reader: the reader to use for this action. Alternatively, you
|
:param reader: the reader to use for this action. Alternatively, you
|
||||||
can attach a reader directly to the input node.
|
can attach a reader directly to the input node.
|
||||||
:param node_unit_test: set to True if you want to output the gradient of a node (backward pass)
|
:param backward_pass: set to True if you want to output the gradient of a node (backward pass)
|
||||||
:input_name: if node_unit_test is True then input_node should contain the input name that
|
:input_name: if backward_pass is True then input_node should contain the input name that
|
||||||
the gradient is performed with respect to.
|
the gradient is performed with respect to.
|
||||||
Returns the output generated by `node`
|
Returns the output generated by `node`
|
||||||
'''
|
'''
|
||||||
|
@ -564,15 +564,15 @@ class Context(AbstractContext):
|
||||||
orig_node_tag = node.tag if hasattr(node, 'tag') else None
|
orig_node_tag = node.tag if hasattr(node, 'tag') else None
|
||||||
node.tag = 'output'
|
node.tag = 'output'
|
||||||
|
|
||||||
config_content = self._generate_eval_config(node, reader, node_unit_test)
|
config_content = self._generate_eval_config(node, reader, backward_pass)
|
||||||
output = self._call_cntk(CNTK_EVAL_CONFIG_FILENAME, config_content)
|
self._call_cntk(CNTK_EVAL_CONFIG_FILENAME, config_content)
|
||||||
|
|
||||||
node.tag = orig_node_tag
|
node.tag = orig_node_tag
|
||||||
|
|
||||||
n = input_name.var_name if isinstance(input_name, ComputationNode) else input_name
|
n = input_name.var_name if isinstance(input_name, ComputationNode) else input_name
|
||||||
out_name = os.path.join(
|
out_name = os.path.join(
|
||||||
self.directory, CNTK_OUTPUT_FILENAME + '.' + \
|
self.directory, CNTK_OUTPUT_FILENAME + '.' + \
|
||||||
((n + '.grad') if node_unit_test else node.var_name))
|
((n + '.grad') if backward_pass else node.var_name))
|
||||||
|
|
||||||
result_content = open(out_name).read()
|
result_content = open(out_name).read()
|
||||||
data = Context._parse_result_output(result_content)
|
data = Context._parse_result_output(result_content)
|
||||||
|
|
|
@ -11,14 +11,17 @@ C = constant
|
||||||
I = input
|
I = input
|
||||||
AA = np.asarray
|
AA = np.asarray
|
||||||
|
|
||||||
def _test(root_node, expected, clean_up=True):
|
def _test(root_node, expected, clean_up=True, backward_pass = False, input_node = None):
|
||||||
with get_new_context() as ctx:
|
with get_new_context() as ctx:
|
||||||
ctx.clean_up = clean_up
|
ctx.clean_up = clean_up
|
||||||
assert not ctx.input_nodes
|
assert not ctx.input_nodes
|
||||||
result = ctx.eval(root_node)
|
result = ctx.eval(root_node, None, backward_pass, input_node)
|
||||||
|
|
||||||
assert len(result) == len(expected)
|
assert len(result) == len(expected)
|
||||||
for res, exp in zip(result, expected):
|
for res, exp in zip(result, expected):
|
||||||
|
print(res)
|
||||||
|
print(exp)
|
||||||
|
print("===========#####")
|
||||||
assert np.allclose(res, exp)
|
assert np.allclose(res, exp)
|
||||||
assert res.shape == AA(exp).shape
|
assert res.shape == AA(exp).shape
|
||||||
|
|
||||||
|
@ -84,9 +87,17 @@ def test_op_add_input_constant(left_arg, right_arg):
|
||||||
],
|
],
|
||||||
2),
|
2),
|
||||||
])
|
])
|
||||||
|
|
||||||
def test_op_mul_input_seq(left_arg, right_arg):
|
def test_op_mul_input_seq(left_arg, right_arg):
|
||||||
expected = [AA(elem)*right_arg for elem in left_arg]
|
expected = [AA(elem)*right_arg for elem in left_arg]
|
||||||
result = I(left_arg, has_sequence_dimension=True) * right_arg
|
result = I(left_arg, has_sequence_dimension=True) * right_arg
|
||||||
_test(result, expected, False)
|
_test(result, expected, False)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("left_arg, right_arg", [
|
||||||
|
([0],[0]), # grad(Cos(0)) = -Sin(0) = 0
|
||||||
|
([1.57079633],[-1]), # grad(Cos(pi/2)) = -Sin(pi/2) = -1
|
||||||
|
([3.14159265],[0]), # grad(Cos(pi)) = -Sin(pi) = 0
|
||||||
|
])
|
||||||
|
def test_cosine_backward(left_arg, right_arg):
|
||||||
|
i = I([left_arg], has_sequence_dimension=False)
|
||||||
|
n = Cosine(i)
|
||||||
|
_test(n, [[AA(right_arg)]], clean_up=False, backward_pass = True, input_node = i)
|
Загрузка…
Ссылка в новой задаче