fastgrnncuda: fix gradient return

This commit is contained in:
Moksh Jain 2019-09-27 14:03:49 +05:30
Родитель 4bcdf1796d
Коммит cb1e26f8ca
1 изменённых файлов: 8 добавлений и 10 удалений

Просмотреть файл

@ -397,6 +397,10 @@ std::vector<torch::Tensor> fastgrnn_unroll_cuda_backward(
auto d_nu = torch::zeros_like(initial_h);
auto d_bias_z = torch::zeros_like(initial_h);
auto d_bias_h_prime = torch::zeros_like(initial_h);
auto d_w1 = torch::empty(0);
auto d_w2 = torch::empty(0);
auto d_u1 = torch::empty(0);
auto d_u2 = torch::empty(0);
bool w_low_rank = w1.size(0) != 0;
bool u_low_rank = u1.size(0) != 0;
@ -501,20 +505,14 @@ std::vector<torch::Tensor> fastgrnn_unroll_cuda_backward(
d_zeta = (d_zeta.sum(0, true)).sum(1, true);
d_nu = (d_nu.sum(0, true)).sum(1, true);
if (w_low_rank) {
auto d_w1 = torch::mm(w2.transpose(0, 1), d_w);
auto d_w2 = torch::mm(d_w, w1.transpose(0, 1));
d_w1 = torch::mm(w2.transpose(0, 1), d_w);
d_w2 = torch::mm(d_w, w1.transpose(0, 1));
d_w = torch::empty(0);
} else {
auto d_w1 = torch::empty(0);
auto d_w2 = torch::empty(0);
}
if(u_low_rank) {
auto d_u1 = torch::mm(u2.transpose(0, 1), d_u);
auto d_u2 = torch::mm(d_u, u1.transpose(0, 1));
d_u1 = torch::mm(u2.transpose(0, 1), d_u);
d_u2 = torch::mm(d_u, u1.transpose(0, 1));
d_u = torch::empty(0);
} else {
auto d_u1 = torch::empty(0);
auto d_u2 = torch::empty(0);
}
return {d_input, d_bias_z, d_bias_h_prime, d_zeta, d_nu, d_old_h, d_w, d_u, d_w1, d_w2, d_u1, d_u2};
}