This commit is contained in:
jeanfad 2016-04-06 10:26:06 +02:00
Родитель b8cd9a1279
Коммит d14daf45f2
1 изменённых файлов: 29 добавлений и 2 удалений

Просмотреть файл

@ -18,10 +18,10 @@ def _test(root_node, expected, clean_up=True, backward_pass = False, input_node
result = ctx.eval(root_node, None, backward_pass, input_node)
assert len(result) == len(expected)
for res, exp in zip(result, expected):
for res, exp in zip(result, expected):
print ("asdf")
print(res)
print(exp)
print("===========#####")
assert np.allclose(res, exp)
assert res.shape == AA(exp).shape
@ -92,6 +92,33 @@ def test_op_mul_input_seq(left_arg, right_arg):
result = I(left_arg, has_sequence_dimension=True) * right_arg
_test(result, expected, False)
@pytest.mark.parametrize("left_arg, right_arg", [
#([30, 2], [10, 3]),
([[30,40], [1,2]], [[30,40], [1,2]]),
])
def test_elemmul_backward(left_arg, right_arg):
expected = AA(right_arg)
# sequence of 1 element, since we have has_sequence_dimension=False
expected = [expected]
# batch of one sample
expected = [expected]
a = I([left_arg], has_sequence_dimension=False)
m = a * right_arg
_test(m, expected, clean_up=False, backward_pass = True, input_node = a)
expected = AA(left_arg)
# sequence of 1 element, since we have has_sequence_dimension=False
expected = [expected]
# batch of one sample
expected = [expected]
b = I([left_arg], has_sequence_dimension=False)
#b.var_name = 'v2'
m = left_arg * b
_test(m, expected, clean_up=False, backward_pass = True, input_node = b)
@pytest.mark.parametrize("left_arg, right_arg", [
([0],[0]), # grad(Cos(0)) = -Sin(0) = 0
([1.57079633],[-1]), # grad(Cos(pi/2)) = -Sin(pi/2) = -1