elem times backward pass test
This commit is contained in:
Родитель
b8cd9a1279
Коммит
d14daf45f2
|
@ -18,10 +18,10 @@ def _test(root_node, expected, clean_up=True, backward_pass = False, input_node
|
|||
result = ctx.eval(root_node, None, backward_pass, input_node)
|
||||
|
||||
assert len(result) == len(expected)
|
||||
for res, exp in zip(result, expected):
|
||||
for res, exp in zip(result, expected):
|
||||
print ("asdf")
|
||||
print(res)
|
||||
print(exp)
|
||||
print("===========#####")
|
||||
assert np.allclose(res, exp)
|
||||
assert res.shape == AA(exp).shape
|
||||
|
||||
|
@ -92,6 +92,33 @@ def test_op_mul_input_seq(left_arg, right_arg):
|
|||
result = I(left_arg, has_sequence_dimension=True) * right_arg
|
||||
_test(result, expected, False)
|
||||
|
||||
@pytest.mark.parametrize("left_arg, right_arg", [
|
||||
#([30, 2], [10, 3]),
|
||||
([[30,40], [1,2]], [[30,40], [1,2]]),
|
||||
])
|
||||
|
||||
def test_elemmul_backward(left_arg, right_arg):
|
||||
|
||||
expected = AA(right_arg)
|
||||
# sequence of 1 element, since we have has_sequence_dimension=False
|
||||
expected = [expected]
|
||||
# batch of one sample
|
||||
expected = [expected]
|
||||
a = I([left_arg], has_sequence_dimension=False)
|
||||
m = a * right_arg
|
||||
_test(m, expected, clean_up=False, backward_pass = True, input_node = a)
|
||||
|
||||
expected = AA(left_arg)
|
||||
# sequence of 1 element, since we have has_sequence_dimension=False
|
||||
expected = [expected]
|
||||
# batch of one sample
|
||||
expected = [expected]
|
||||
b = I([left_arg], has_sequence_dimension=False)
|
||||
#b.var_name = 'v2'
|
||||
m = left_arg * b
|
||||
_test(m, expected, clean_up=False, backward_pass = True, input_node = b)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("left_arg, right_arg", [
|
||||
([0],[0]), # grad(Cos(0)) = -Sin(0) = 0
|
||||
([1.57079633],[-1]), # grad(Cos(pi/2)) = -Sin(pi/2) = -1
|
||||
|
|
Загрузка…
Ссылка в новой задаче