CNTK v2 library: Fix a bug in Reshape where it failed to handle the new shape's leading axis dimensionality being 1
This commit is contained in:
Родитель
9c8884eaa0
Коммит
5b8d122a4b
|
@ -150,12 +150,16 @@ public:
|
|||
|
||||
virtual void /*ComputationNode::*/ ForwardProp(const FrameRange& fr) override
|
||||
{
|
||||
ValueFor(fr).AssignValuesOf(InputRef(0).ValueFor(fr));
|
||||
auto result = ValueFor(fr);
|
||||
auto inputValue = InputRef(0).ValueFor(fr);
|
||||
result.AssignValuesOf(inputValue.Reshaped(result.GetNumRows(), result.GetNumCols()));
|
||||
}
|
||||
|
||||
virtual void /*ComputationNode::*/ BackpropTo(const size_t inputIndex, const FrameRange& fr) override
|
||||
{
|
||||
InputRef(inputIndex).GradientFor(fr) += GradientFor(fr);
|
||||
auto gradient = GradientFor(fr);
|
||||
auto inputGradient = InputRef(inputIndex).GradientFor(fr);
|
||||
inputGradient += gradient.Reshaped(inputGradient.GetNumRows(), inputGradient.GetNumCols());
|
||||
}
|
||||
|
||||
// Always answers false: BackpropTo reads only gradients (GradientFor), never
// this node's forward output value.
virtual bool OutputUsedInComputingInputNodesGradients() const override
{
    return false;
}
|
||||
|
|
|
@ -157,13 +157,29 @@ def test_op_reshape_gradient_accumulation(device_id, precision):
|
|||
device_id=device_id, precision=precision)
|
||||
|
||||
|
||||
def test_op_reshape_parameter():
    """Reshape a (4, 2) parameter to (8, 1) and verify forward and backward.

    Forward must equal the NumPy reshape of the initial values; the gradient
    with respect to the parameter, for an all-ones root gradient, must be
    all ones in the parameter's original shape.
    """
    from .. import reshape, parameter

    param_shape = (4, 2)
    param_value = np.random.random(param_shape)
    param = parameter(init=param_value)
    param_new_shape = (8, 1)
    param_reshaped = reshape(param, param_new_shape)

    expected_forward = np.copy(param_value).reshape(param_new_shape)
    state, result = param_reshaped.forward({}, [param_reshaped.output], [param_reshaped.output])
    # np.allclose only returns a bool; without `assert` the check is a no-op
    # and the test can never fail.
    assert np.allclose(result[param_reshaped.output], expected_forward)

    grad = param_reshaped.backward(state, np.ones(param_new_shape), [param])
    assert np.allclose(grad[param], np.ones(param_shape))
|
||||
|
||||
|
||||
# Each case is a tuple:
#   (input_data, slice_params, expected_result)
# where slice_params is (beg_index, end_index, axis).
SLICE_TEST_CASES_STATIC = [
    ([[1, 2], [-3, 4]], (1, 2, 0), [[-3, 4]]),
    ([[1, 2], [-3, 4]], (1, 2, 1), [[2], [4]]),
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("input_data, slice_params, expected_result",
|
||||
SLICE_TEST_CASES_STATIC)
|
||||
def test_op_slice(input_data, slice_params, expected_result, device_id, precision):
|
||||
|
|
Загрузка…
Ссылка в новой задаче