diff --git a/Source/ComputationNetworkLib/ReshapingNodes.h b/Source/ComputationNetworkLib/ReshapingNodes.h index c29c944c3..1ca9072da 100644 --- a/Source/ComputationNetworkLib/ReshapingNodes.h +++ b/Source/ComputationNetworkLib/ReshapingNodes.h @@ -150,12 +150,16 @@ public: virtual void /*ComputationNode::*/ ForwardProp(const FrameRange& fr) override { - ValueFor(fr).AssignValuesOf(InputRef(0).ValueFor(fr)); + auto result = ValueFor(fr); + auto inputValue = InputRef(0).ValueFor(fr); + result.AssignValuesOf(inputValue.Reshaped(result.GetNumRows(), result.GetNumCols())); } virtual void /*ComputationNode::*/ BackpropTo(const size_t inputIndex, const FrameRange& fr) override { - InputRef(inputIndex).GradientFor(fr) += GradientFor(fr); + auto gradient = GradientFor(fr); + auto inputGradient = InputRef(inputIndex).GradientFor(fr); + inputGradient += gradient.Reshaped(inputGradient.GetNumRows(), inputGradient.GetNumCols()); } virtual bool OutputUsedInComputingInputNodesGradients() const override { return false; } diff --git a/bindings/python/cntk/ops/tests/reshaping_test.py b/bindings/python/cntk/ops/tests/reshaping_test.py index 2b207dc55..00f466026 100644 --- a/bindings/python/cntk/ops/tests/reshaping_test.py +++ b/bindings/python/cntk/ops/tests/reshaping_test.py @@ -157,13 +157,29 @@ def test_op_reshape_gradient_accumulation(device_id, precision): device_id=device_id, precision=precision) +def test_op_reshape_parameter(): + from .. import reshape, parameter + + param_shape = (4,2) + param_value = np.random.random(param_shape) + param = parameter(init=param_value) + param_new_shape = (8,1) + param_reshaped = reshape(param, param_new_shape) + + expected_forward = np.copy(param_value).reshape(param_new_shape) + state, result = param_reshaped.forward({}, [param_reshaped.output], [param_reshaped.output]) + np.allclose(result[param_reshaped.output], expected_forward) + + grad = param_reshaped.backward(state, np.ones(param_new_shape), [param]) + np.allclose(grad[param], np.ones(param_shape)) + + SLICE_TEST_CASES_STATIC = [ #(input_data, slice_params(beg_index, end_index, axis), expected_result) ([[1, 2], [-3, 4]], (1, 2, 0), [[-3, 4]]), ([[1,2],[-3,4]], (1,2,1), [[2],[4]]), ] - @pytest.mark.parametrize("input_data, slice_params, expected_result", SLICE_TEST_CASES_STATIC) def test_op_slice(input_data, slice_params, expected_result, device_id, precision):