GatherNode backward: add check for no dynamic axis

Previously, to resolve issue of gather producing incorrect gradient
values, validity mask check was added to ensure we don't count non-valid
cells as 0.
However, this check is needed only for input that has dynamic axis, i.e.
inputs that have MBLayout.
This commit is contained in:
Bowen Bao 2018-09-20 14:51:55 -07:00
Родитель 0a3eb3b813
Коммит da6b0bc71f
2 изменённых файлов: 31 добавлений и 2 удалений

Просмотреть файл

@ -2093,7 +2093,6 @@ public:
if (inputIndex == 1) //only right operand need calculate gradient if (inputIndex == 1) //only right operand need calculate gradient
{ {
let& indices = InputRef(0).Value(); let& indices = InputRef(0).Value();
const auto& indicesMask = InputRef(0).GetMBLayout()->GetColumnsValidityMask(indices.GetDeviceId());
auto& sourceGradient = InputRef(1).Gradient(); auto& sourceGradient = InputRef(1).Gradient();
auto& outputGradient = Gradient(); auto& outputGradient = Gradient();
const auto& sampleLayout = InputRef(1).GetSampleLayout(); const auto& sampleLayout = InputRef(1).GetSampleLayout();
@ -2110,9 +2109,17 @@ public:
row_elements *= dims[i]; row_elements *= dims[i];
} }
if (InputRef(0).HasMBLayout())
{
const auto& indicesMask = InputRef(0).GetMBLayout()->GetColumnsValidityMask(indices.GetDeviceId());
sourceGradient.ScatterToIndices(outputGradient, indices, row_elements, &indicesMask); sourceGradient.ScatterToIndices(outputGradient, indices, row_elements, &indicesMask);
} }
else else
{
sourceGradient.ScatterToIndices(outputGradient, indices, row_elements);
}
}
else
{ {
//No graidents pass through indices (the left operand), so do nothing //No graidents pass through indices (the left operand), so do nothing
} }

Просмотреть файл

@ -590,6 +590,28 @@ def test_gather_op_with_axis(device_id, precision):
assert np.allclose(output, z) assert np.allclose(output, z)
def test_gather_op_backward(device_id, precision):
a_data = [AA([[0],[1]], dtype=PRECISION_TO_TYPE[precision]),
AA([[3],[4]], dtype=PRECISION_TO_TYPE[precision])]
a = C.input_variable((2,1), dtype=PRECISION_TO_TYPE[precision])
r_data = np.arange(12).reshape(6,2).astype(PRECISION_TO_TYPE[precision])
r = C.parameter(shape=r_data.data, init=r_data)
g = C.gather(r, a)
grad = g.grad(a_data, wrt=[r])
expectd = np.asarray([[1., 1.], [1., 1.], [0., 0.], [1., 1.], [1., 1.], [0., 0.]]).astype(PRECISION_TO_TYPE[precision])
assert np.array_equal(grad, expectd)
# test without dynamic axis
data = np.array([ [1.0, 1.2, 1.9], [2.3, 3.4, 3.9], [4.5, 5.7, 5.9], ]).astype(PRECISION_TO_TYPE[precision])
indices = np.array([ 0, 2]).astype(PRECISION_TO_TYPE[precision]).astype(PRECISION_TO_TYPE[precision])
expectd = np.array([[1., 1., 1.], [0., 0., 0.], [1., 1., 1.]]).astype(PRECISION_TO_TYPE[precision])
x = C.input_variable(**C.layers.typing.ParameterTensor[3, 3], needs_gradient=True, dtype=PRECISION_TO_TYPE[precision])
i = C.constant(indices, dtype=PRECISION_TO_TYPE[precision])
y = C.gather(x, i)
grad = y.grad(data, wrt=[x])
assert np.allclose(expectd, grad)
def test_convert_dynamic_axis(): def test_convert_dynamic_axis():
#test fix batch size #test fix batch size
batch_size = 4 batch_size = 4