This commit is contained in:
jeanfad 2016-06-03 10:40:47 +02:00
Родитель e31e0be443
Коммит a9d44401ad
1 изменённых файлов: 45 добавлений и 2 удалений

Просмотреть файл

@ -80,8 +80,6 @@ def test_op_slice(input_data, slice_params, expected_result, device_id, precisio
# The second for batch of one sample.
a = I([input_data])
def op_slice(x, beg_index, end_index, axis):
return x[beg_index:end_index]
def _ax_slices(x, beg_index, end_index, axis):
'''
@ -232,6 +230,51 @@ def test_op_slice_overload(device_id, precision):
result = a[1,object(),2]
SPLICE_TEST_CASES = [
#(input_data1, input_data2, axis, expected_result)
([1], [2], 0, [1,2]),
([[1,2],[4,5]], [[10,20],[30, 40],[50, 60]], 0, [[1, 2],[4, 5],[10, 20],[30, 40],[50, 60]]),
([[1,2],[4,5]], [[10,20,30],[40, 50, 60]], 1, [[1,2,10,20,30],[4,5,40,50,60]]),
([[[1,2],[3,4]],[[5,6],[7,8]]], [[10,20],[30,40]], 0, [[[1,2],[3,4]],[[5,6],[7,8]],[[10,20],[30,40]]]),
]
@pytest.mark.parametrize("input_data1, input_data2, axis, expected_result", SPLICE_TEST_CASES)
def test_op_splice(input_data1, input_data2, axis, expected_result, device_id, precision):
# Forward pass test
#==================
# We compute the expected output for the forward pass.
# We need two surrounding brackets:
# The first for sequences (length=1, since we have dynamic_axis='').
# The second for batch of one sample.
a = I([input_data1])
b = I([input_data2])
def op_splice(x, y, axis):
return np.concatenate((x,y), axis)
# splice using the operator
result = C.splice((a, b), axis)
unittest_helper(result, None, [[expected_result]], device_id=device_id,
precision=precision, clean_up=True, backward_pass=False)
# Backward pass test
# ==================
# The gradient of the splice operator is all ones in the shape of the input
def grad_slice(x):
return np.ones_like(x)
expected_gradient1 = grad_slice(np.asarray(input_data1))
expected_gradient2 = grad_slice(np.asarray(input_data2))
unittest_helper(result, None, [[expected_gradient1]], device_id = device_id,
precision=precision, clean_up=True, backward_pass=True, input_node=a)
unittest_helper(result, None, [[expected_gradient2]], device_id = device_id,
precision=precision, clean_up=True, backward_pass=True, input_node=b)
TRANSPOSE_DIMS_TEST_CASES = [
#(input_shape, axis1, axis2, expected_output_shape)
([2, 3], 0, 1, [3, 2]),