From da6b0bc71fbbca14556ccad34ffbe51ea611187c Mon Sep 17 00:00:00 2001 From: Bowen Bao Date: Thu, 20 Sep 2018 14:51:55 -0700 Subject: [PATCH] GatherNode backward: add check for no dynamic axis Previously, to resolve the issue of gather producing incorrect gradient values, a validity mask check was added to ensure we don't count non-valid cells as 0. However, this check is needed only for inputs that have a dynamic axis, i.e. inputs that have an MBLayout. --- Source/ComputationNetworkLib/ReshapingNodes.h | 11 ++++++++-- .../python/cntk/ops/tests/reshaping_test.py | 22 +++++++++++++++++++ 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/Source/ComputationNetworkLib/ReshapingNodes.h b/Source/ComputationNetworkLib/ReshapingNodes.h index 25bb1aaf2..49cf906a9 100644 --- a/Source/ComputationNetworkLib/ReshapingNodes.h +++ b/Source/ComputationNetworkLib/ReshapingNodes.h @@ -2093,7 +2093,6 @@ public: if (inputIndex == 1) //only right operand need calculate gradient { let& indices = InputRef(0).Value(); - const auto& indicesMask = InputRef(0).GetMBLayout()->GetColumnsValidityMask(indices.GetDeviceId()); auto& sourceGradient = InputRef(1).Gradient(); auto& outputGradient = Gradient(); const auto& sampleLayout = InputRef(1).GetSampleLayout(); @@ -2110,7 +2109,15 @@ public: row_elements *= dims[i]; } - sourceGradient.ScatterToIndices(outputGradient, indices, row_elements, &indicesMask); + if (InputRef(0).HasMBLayout()) + { + const auto& indicesMask = InputRef(0).GetMBLayout()->GetColumnsValidityMask(indices.GetDeviceId()); + sourceGradient.ScatterToIndices(outputGradient, indices, row_elements, &indicesMask); + } + else + { + sourceGradient.ScatterToIndices(outputGradient, indices, row_elements); + } } else { diff --git a/bindings/python/cntk/ops/tests/reshaping_test.py b/bindings/python/cntk/ops/tests/reshaping_test.py index efb11a40f..4a397753b 100644 --- a/bindings/python/cntk/ops/tests/reshaping_test.py +++ b/bindings/python/cntk/ops/tests/reshaping_test.py @@ -590,6 
+590,28 @@ def test_gather_op_with_axis(device_id, precision): assert np.allclose(output, z) +def test_gather_op_backward(device_id, precision): + a_data = [AA([[0],[1]], dtype=PRECISION_TO_TYPE[precision]), + AA([[3],[4]], dtype=PRECISION_TO_TYPE[precision])] + a = C.input_variable((2,1), dtype=PRECISION_TO_TYPE[precision]) + r_data = np.arange(12).reshape(6,2).astype(PRECISION_TO_TYPE[precision]) + r = C.parameter(shape=r_data.shape, init=r_data) + g = C.gather(r, a) + grad = g.grad(a_data, wrt=[r]) + expected = np.asarray([[1., 1.], [1., 1.], [0., 0.], [1., 1.], [1., 1.], [0., 0.]]).astype(PRECISION_TO_TYPE[precision]) + assert np.array_equal(grad, expected) + + # Repeat without a dynamic axis: x is typed as ParameterTensor[3, 3] and the indices are a constant + data = np.array([ [1.0, 1.2, 1.9], [2.3, 3.4, 3.9], [4.5, 5.7, 5.9], ]).astype(PRECISION_TO_TYPE[precision]) + indices = np.array([ 0, 2]).astype(PRECISION_TO_TYPE[precision]) + expected = np.array([[1., 1., 1.], [0., 0., 0.], [1., 1., 1.]]).astype(PRECISION_TO_TYPE[precision]) + x = C.input_variable(**C.layers.typing.ParameterTensor[3, 3], needs_gradient=True, dtype=PRECISION_TO_TYPE[precision]) + i = C.constant(indices, dtype=PRECISION_TO_TYPE[precision]) + y = C.gather(x, i) + grad = y.grad(data, wrt=[x]) + assert np.allclose(expected, grad) + + def test_convert_dynamic_axis(): #test fix batch size batch_size = 4