GatherNode backward: add check for no dynamic axis
Previously, to resolve issue of gather producing incorrect gradient values, validity mask check was added to ensure we don't count non-valid cells as 0. However, this check is needed only for input that has dynamic axis, i.e. inputs that have MBLayout.
This commit is contained in:
Родитель
0a3eb3b813
Коммит
da6b0bc71f
|
@ -2093,7 +2093,6 @@ public:
|
||||||
if (inputIndex == 1) //only right operand need calculate gradient
|
if (inputIndex == 1) //only right operand need calculate gradient
|
||||||
{
|
{
|
||||||
let& indices = InputRef(0).Value();
|
let& indices = InputRef(0).Value();
|
||||||
const auto& indicesMask = InputRef(0).GetMBLayout()->GetColumnsValidityMask(indices.GetDeviceId());
|
|
||||||
auto& sourceGradient = InputRef(1).Gradient();
|
auto& sourceGradient = InputRef(1).Gradient();
|
||||||
auto& outputGradient = Gradient();
|
auto& outputGradient = Gradient();
|
||||||
const auto& sampleLayout = InputRef(1).GetSampleLayout();
|
const auto& sampleLayout = InputRef(1).GetSampleLayout();
|
||||||
|
@ -2110,9 +2109,17 @@ public:
|
||||||
row_elements *= dims[i];
|
row_elements *= dims[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (InputRef(0).HasMBLayout())
|
||||||
|
{
|
||||||
|
const auto& indicesMask = InputRef(0).GetMBLayout()->GetColumnsValidityMask(indices.GetDeviceId());
|
||||||
sourceGradient.ScatterToIndices(outputGradient, indices, row_elements, &indicesMask);
|
sourceGradient.ScatterToIndices(outputGradient, indices, row_elements, &indicesMask);
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
|
{
|
||||||
|
sourceGradient.ScatterToIndices(outputGradient, indices, row_elements);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else
|
||||||
{
|
{
|
||||||
//No graidents pass through indices (the left operand), so do nothing
|
//No graidents pass through indices (the left operand), so do nothing
|
||||||
}
|
}
|
||||||
|
|
|
@ -590,6 +590,28 @@ def test_gather_op_with_axis(device_id, precision):
|
||||||
assert np.allclose(output, z)
|
assert np.allclose(output, z)
|
||||||
|
|
||||||
|
|
||||||
|
def test_gather_op_backward(device_id, precision):
|
||||||
|
a_data = [AA([[0],[1]], dtype=PRECISION_TO_TYPE[precision]),
|
||||||
|
AA([[3],[4]], dtype=PRECISION_TO_TYPE[precision])]
|
||||||
|
a = C.input_variable((2,1), dtype=PRECISION_TO_TYPE[precision])
|
||||||
|
r_data = np.arange(12).reshape(6,2).astype(PRECISION_TO_TYPE[precision])
|
||||||
|
r = C.parameter(shape=r_data.data, init=r_data)
|
||||||
|
g = C.gather(r, a)
|
||||||
|
grad = g.grad(a_data, wrt=[r])
|
||||||
|
expectd = np.asarray([[1., 1.], [1., 1.], [0., 0.], [1., 1.], [1., 1.], [0., 0.]]).astype(PRECISION_TO_TYPE[precision])
|
||||||
|
assert np.array_equal(grad, expectd)
|
||||||
|
|
||||||
|
# test without dynamic axis
|
||||||
|
data = np.array([ [1.0, 1.2, 1.9], [2.3, 3.4, 3.9], [4.5, 5.7, 5.9], ]).astype(PRECISION_TO_TYPE[precision])
|
||||||
|
indices = np.array([ 0, 2]).astype(PRECISION_TO_TYPE[precision]).astype(PRECISION_TO_TYPE[precision])
|
||||||
|
expectd = np.array([[1., 1., 1.], [0., 0., 0.], [1., 1., 1.]]).astype(PRECISION_TO_TYPE[precision])
|
||||||
|
x = C.input_variable(**C.layers.typing.ParameterTensor[3, 3], needs_gradient=True, dtype=PRECISION_TO_TYPE[precision])
|
||||||
|
i = C.constant(indices, dtype=PRECISION_TO_TYPE[precision])
|
||||||
|
y = C.gather(x, i)
|
||||||
|
grad = y.grad(data, wrt=[x])
|
||||||
|
assert np.allclose(expectd, grad)
|
||||||
|
|
||||||
|
|
||||||
def test_convert_dynamic_axis():
|
def test_convert_dynamic_axis():
|
||||||
#test fix batch size
|
#test fix batch size
|
||||||
batch_size = 4
|
batch_size = 4
|
||||||
|
|
Загрузка…
Ссылка в новой задаче