add splice unittest
This commit is contained in:
Родитель
e31e0be443
Коммит
a9d44401ad
|
@ -80,8 +80,6 @@ def test_op_slice(input_data, slice_params, expected_result, device_id, precisio
|
|||
# The second for batch of one sample.
|
||||
|
||||
a = I([input_data])
|
||||
def op_slice(x, beg_index, end_index, axis):
|
||||
return x[beg_index:end_index]
|
||||
|
||||
def _ax_slices(x, beg_index, end_index, axis):
|
||||
'''
|
||||
|
@ -232,6 +230,51 @@ def test_op_slice_overload(device_id, precision):
|
|||
result = a[1,object(),2]
|
||||
|
||||
|
||||
SPLICE_TEST_CASES = [
|
||||
#(input_data1, input_data2, axis, expected_result)
|
||||
([1], [2], 0, [1,2]),
|
||||
([[1,2],[4,5]], [[10,20],[30, 40],[50, 60]], 0, [[1, 2],[4, 5],[10, 20],[30, 40],[50, 60]]),
|
||||
([[1,2],[4,5]], [[10,20,30],[40, 50, 60]], 1, [[1,2,10,20,30],[4,5,40,50,60]]),
|
||||
([[[1,2],[3,4]],[[5,6],[7,8]]], [[10,20],[30,40]], 0, [[[1,2],[3,4]],[[5,6],[7,8]],[[10,20],[30,40]]]),
|
||||
]
|
||||
@pytest.mark.parametrize("input_data1, input_data2, axis, expected_result", SPLICE_TEST_CASES)
|
||||
def test_op_splice(input_data1, input_data2, axis, expected_result, device_id, precision):
|
||||
# Forward pass test
|
||||
#==================
|
||||
# We compute the expected output for the forward pass.
|
||||
# We need two surrounding brackets:
|
||||
# The first for sequences (length=1, since we have dynamic_axis='').
|
||||
# The second for batch of one sample.
|
||||
|
||||
a = I([input_data1])
|
||||
b = I([input_data2])
|
||||
|
||||
def op_splice(x, y, axis):
|
||||
return np.concatenate((x,y), axis)
|
||||
|
||||
# splice using the operator
|
||||
result = C.splice((a, b), axis)
|
||||
|
||||
unittest_helper(result, None, [[expected_result]], device_id=device_id,
|
||||
precision=precision, clean_up=True, backward_pass=False)
|
||||
|
||||
# Backward pass test
|
||||
# ==================
|
||||
# The gradient of the splice operator is all ones in the shape of the input
|
||||
|
||||
def grad_slice(x):
|
||||
return np.ones_like(x)
|
||||
|
||||
expected_gradient1 = grad_slice(np.asarray(input_data1))
|
||||
expected_gradient2 = grad_slice(np.asarray(input_data2))
|
||||
|
||||
unittest_helper(result, None, [[expected_gradient1]], device_id = device_id,
|
||||
precision=precision, clean_up=True, backward_pass=True, input_node=a)
|
||||
|
||||
unittest_helper(result, None, [[expected_gradient2]], device_id = device_id,
|
||||
precision=precision, clean_up=True, backward_pass=True, input_node=b)
|
||||
|
||||
|
||||
TRANSPOSE_DIMS_TEST_CASES = [
|
||||
#(input_shape, axis1, axis2, expected_output_shape)
|
||||
([2, 3], 0, 1, [3, 2]),
|
||||
|
|
Загрузка…
Ссылка в новой задаче