Restructuring tests
This commit is contained in:
Родитель
358082b5d6
Коммит
14d091b158
|
@ -9,82 +9,42 @@ from ..reader import *
|
|||
# keeping things short
|
||||
C = constant
|
||||
I = input
|
||||
|
||||
AA = np.asarray
|
||||
|
||||
def _test(root_node, expected, clean_up=True):
|
||||
with get_new_context() as ctx:
|
||||
ctx.clean_up = clean_up
|
||||
assert not ctx.input_nodes
|
||||
result = ctx.eval(root_node)
|
||||
expected = np.asarray(expected)
|
||||
expected = AA(expected)
|
||||
assert result.shape == expected.shape or result.shape == (
|
||||
1, 1) and expected.shape == ()
|
||||
assert np.all(result == expected)
|
||||
assert np.allclose(result, expected)
|
||||
|
||||
_VALUES = [0, [[1, 2], [3, 4]], [10.1, -20.2], 1.1]
|
||||
|
||||
@pytest.mark.parametrize('root_node, expected', [
|
||||
# __add__ / __radd__
|
||||
(C(0) + C(1), 1),
|
||||
(C(0) + 1, 1),
|
||||
(0 + C(1), 1),
|
||||
@pytest.fixture(scope="module", params=_VALUES)
|
||||
def left_arg(request):
|
||||
return request.param
|
||||
|
||||
# __sub__ / __rsub__
|
||||
(C(0) - C(1), -1),
|
||||
(C(0) - 1, -1),
|
||||
(0 - C(1), -1),
|
||||
right_arg = left_arg
|
||||
|
||||
# __mul__ / __rmul__ --> element-wise (!) multiplication
|
||||
(C(0) * C(1), 0),
|
||||
(C(0) * 1, 0),
|
||||
(0 * C(1), 0),
|
||||
def test_op_add(left_arg, right_arg):
|
||||
expected = AA(left_arg) + AA(right_arg)
|
||||
_test(C(left_arg) + right_arg, expected)
|
||||
_test(C(left_arg) + C(right_arg), expected)
|
||||
_test(left_arg + C(right_arg), expected)
|
||||
_test(left_arg + C(left_arg) + right_arg, left_arg+expected)
|
||||
|
||||
# chaining
|
||||
(C(2) * C(3) + C(1.2), 7.2),
|
||||
(C(2) * (C(3) + C(1.2)), 8.4),
|
||||
def test_op_minus(left_arg, right_arg):
|
||||
expected = AA(left_arg) - AA(right_arg)
|
||||
_test(C(left_arg) - right_arg, expected)
|
||||
_test(C(left_arg) - C(right_arg), expected)
|
||||
_test(left_arg - C(right_arg), expected)
|
||||
_test(left_arg - C(left_arg) + right_arg, left_arg-expected)
|
||||
|
||||
# normal ops
|
||||
|
||||
(C(np.ones((2, 3)) * 3), [[3, 3, 3], [3, 3, 3]]),
|
||||
(C(np.ones((2, 3)) * 3) + \
|
||||
np.vstack([np.ones(3), np.ones(3) + 1]), [[4, 4, 4], [5, 5, 5]]),
|
||||
(C(np.ones((2, 3)) * 3) * \
|
||||
np.vstack([np.ones(3), np.ones(3) + 1]), [[3, 3, 3], [6, 6, 6]]),
|
||||
|
||||
# special treatment of inputs in RowStack
|
||||
# (RowStack((C(1), C(2))), [[1],[2]]), # TODO figure out the real semantic
|
||||
# of RowStack
|
||||
|
||||
# the following test fails because Constant() ignores the cols parameter
|
||||
#(RowStack((C(1, rows=2, cols=2), C(2, rows=2, cols=2))), [[1,1,2,2], [1,1,2,2]])
|
||||
|
||||
# __abs__
|
||||
# uncomennt, once Abs() as ComputationNode is moved from standard function
|
||||
# to ComputationNode
|
||||
(abs(C(-3)), 3),
|
||||
(abs(C(3)), 3),
|
||||
(abs(C([[-1, 2], [50, -0]])), [[1, 2], [50, 0]]),
|
||||
|
||||
# more complex stuff
|
||||
#(Plus(C(5), 3), 8),
|
||||
])
|
||||
def test_overload_eval(root_node, expected):
|
||||
_test(root_node, expected)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('root_node, expected', [
|
||||
# __add__ / __radd__
|
||||
(C(np.asarray([1, 2])) + 0, [1, 2]),
|
||||
(C(np.asarray([1, 2])) + .1, [1.1, 2.1]),
|
||||
(.1 + C(np.asarray([1, 2])), [1.1, 2.1]),
|
||||
(C(np.asarray([1, 2])) * 0, [0, 0]),
|
||||
(C(np.asarray([1, 2])) * .1, [0.1, 0.2]),
|
||||
(.1 * C(np.asarray([1, 2])), [0.1, 0.2]),
|
||||
(C(np.asarray([[1, 2], [3, 4]])) + .1, [[1.1, 2.1], [3.1, 4.1]]),
|
||||
(C(np.asarray([[1, 2], [3, 4]])) * 2, [[2, 4], [6, 8]]),
|
||||
(2 * C(np.asarray([[1, 2], [3, 4]])), [[2, 4], [6, 8]]),
|
||||
(2 * C(np.asarray([[1, 2], [3, 4]])) + 100, [[102, 104], [106, 108]]),
|
||||
(C(np.asarray([[1, 2], [3, 4]]))
|
||||
* C(np.asarray([[1, 2], [3, 4]])), [[1, 4], [9, 16]]),
|
||||
])
|
||||
def test_ops_on_numpy(root_node, expected, tmpdir):
|
||||
_test(root_node, expected, clean_up=False)
|
||||
def test_op_times(left_arg, right_arg):
|
||||
expected = AA(left_arg) * AA(right_arg)
|
||||
_test(C(left_arg) * right_arg, expected)
|
||||
_test(C(left_arg) * C(right_arg), expected)
|
||||
_test(left_arg * C(right_arg), expected)
|
||||
|
|
Загрузка…
Ссылка в новой задаче