fix Gather op's incorrect gradient value.
* the error was due to that we pad 0 as default value for missing gaps. All these then each contribute 1 to the gradient of reference at index 0. The fix is to mask missing values in indices matrix to negative, and in Matrix scatter implementation to check and skip negative indices. (previous Matrix CPU implementation already checks for negative indices)
This commit is contained in:
Родитель
07fd96f357
Коммит
1e058cedcf
|
@ -2092,6 +2092,7 @@ public:
|
|||
{
|
||||
if (inputIndex == 1) //only right operand need calculate gradient
|
||||
{
|
||||
InputRef(0).MaskMissingValueColumnsTo(FrameRange(InputRef(0).GetMBLayout()), (ElemType) -1.0);
|
||||
let& indices = InputRef(0).Value();
|
||||
auto& sourceGradient = InputRef(1).Gradient();
|
||||
auto& outputGradient = Gradient();
|
||||
|
|
|
@ -5831,6 +5831,8 @@ __global__ void _scatterToIndices(ElemType *indices,
|
|||
{
|
||||
size_t indices_index = index / num_row_elements;
|
||||
size_t offset = index % num_row_elements;
|
||||
//Skip missing values
|
||||
if (indices[indices_index] < 0) return;
|
||||
//We resort to nondeterministic behavior (floating point addition is not associative).
|
||||
//Note that the CPU parallel algorithm will have poor performance on the GPU because of thread divergence
|
||||
atomicAdd(&buffer[(size_t)(unsigned long long int)indices[indices_index] * num_row_elements + offset], value[index]);
|
||||
|
|
|
@ -499,6 +499,18 @@ def test_op_gather_sparse(device_id):
|
|||
assert np.array_equal(res, [[[0, 0, 0, 0, 0, 1], [0, 0, 0, 0, 0, 1]], [[0, 0, 1, 0, 0, 0], [0, 0, 0, 0, 1, 0]], [[1, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0]]])
|
||||
|
||||
|
||||
def test_op_gather_grad(device_id):
|
||||
dim = 10
|
||||
ii = C.sequence.input_variable(())
|
||||
param = C.parameter((dim, 1), init=np.reshape(np.arange(dim), (dim,1)).astype(np.float32))
|
||||
ss = C.gather(param, ii)
|
||||
data = [[0], [0,1,2], [1,2,3,4,5, 6]]
|
||||
grad1 = ss.grad(data, wrt=[param])
|
||||
ss2 = C.times(C.one_hot(ii, num_classes=dim, sparse_output=False), param)
|
||||
grad2 = ss2.grad(data, wrt=[param])
|
||||
assert np.array_equal(grad1, grad2)
|
||||
|
||||
|
||||
def test_op_scatter_sparse(device_id):
|
||||
input_sparse_indices = [[1, 3, 5, 5], [2, 4], [0, 2]]
|
||||
vocab_size = 6
|
||||
|
|
Загрузка…
Ссылка в новой задаче