CNTK v2 library: Fix a bug in Reshape where it failed to handle the new shape's leading axis dimensionality being 1

This commit is contained in:
Amit Agarwal 2017-02-05 21:59:04 -08:00
Родитель 9c8884eaa0
Коммит 5b8d122a4b
2 изменённых файлов: 23 добавлений и 3 удалений

Просмотреть файл

@ -150,12 +150,16 @@ public:
// Forward pass: writes the input's values for frame range `fr` into this node's output.
//
// The input slice is re-viewed with the output's row/column counts before the
// assignment. Per the commit description, assigning the un-reshaped input failed
// when the new shape's leading axis has dimensionality 1, so the plain
// `ValueFor(fr).AssignValuesOf(InputRef(0).ValueFor(fr))` call must not run here.
virtual void /*ComputationNode::*/ ForwardProp(const FrameRange& fr) override
{
    // NOTE(review): `result` is presumed to be a view onto this node's value storage,
    // so assigning into it updates the node's output -- confirm against ValueFor().
    auto result = ValueFor(fr);
    auto inputValue = InputRef(0).ValueFor(fr);

    // Match the input's layout to the output's rows x cols, then copy.
    result.AssignValuesOf(inputValue.Reshaped(result.GetNumRows(), result.GetNumCols()));
}
virtual void /*ComputationNode::*/ BackpropTo(const size_t inputIndex, const FrameRange& fr) override
{
InputRef(inputIndex).GradientFor(fr) += GradientFor(fr);
auto gradient = GradientFor(fr);
auto inputGradient = InputRef(inputIndex).GradientFor(fr);
inputGradient += gradient.Reshaped(inputGradient.GetNumRows(), inputGradient.GetNumCols());
}
// Always answers false.
// NOTE(review): judging by the name, this tells the framework that the node's
// output value is not needed to compute its inputs' gradients -- confirm against
// the base-class contract.
virtual bool OutputUsedInComputingInputNodesGradients() const override
{
    return false;
}

Просмотреть файл

@ -157,13 +157,29 @@ def test_op_reshape_gradient_accumulation(device_id, precision):
device_id=device_id, precision=precision)
def test_op_reshape_parameter():
    """Check reshape applied directly to a parameter, from shape (4, 2) to (8, 1).

    The target shape has a trailing dimension of 1. Both the forward values and
    the gradient that flows back to the parameter are compared against NumPy
    references, and each comparison is asserted so a mismatch fails the test.
    """
    from .. import reshape, parameter

    param_shape = (4, 2)
    param_value = np.random.random(param_shape)
    param = parameter(init=param_value)

    param_new_shape = (8, 1)
    param_reshaped = reshape(param, param_new_shape)

    # Forward: the reshaped output must equal NumPy's reshape of the initial value.
    expected_forward = np.copy(param_value).reshape(param_new_shape)
    state, result = param_reshaped.forward({}, [param_reshaped.output], [param_reshaped.output])
    assert np.allclose(result[param_reshaped.output], expected_forward)

    # Backward: an all-ones root gradient in the new shape must come back as
    # all ones in the parameter's original shape.
    grad = param_reshaped.backward(state, np.ones(param_new_shape), [param])
    assert np.allclose(grad[param], np.ones(param_shape))
# Parametrization data for test_op_slice.
# Each entry is a tuple of:
#   (input_data, slice_params(beg_index, end_index, axis), expected_result)
SLICE_TEST_CASES_STATIC = [
    (
        [[1, 2], [-3, 4]],
        (1, 2, 0),
        [[-3, 4]],
    ),
    (
        [[1, 2], [-3, 4]],
        (1, 2, 1),
        [[2], [4]],
    ),
]
@pytest.mark.parametrize("input_data, slice_params, expected_result",
SLICE_TEST_CASES_STATIC)
def test_op_slice(input_data, slice_params, expected_result, device_id, precision):