зеркало из https://github.com/microsoft/EdgeML.git
fastgrnncuda: fix gradient return
This commit is contained in:
Родитель
4bcdf1796d
Коммит
cb1e26f8ca
|
@ -397,6 +397,10 @@ std::vector<torch::Tensor> fastgrnn_unroll_cuda_backward(
|
|||
auto d_nu = torch::zeros_like(initial_h);
|
||||
auto d_bias_z = torch::zeros_like(initial_h);
|
||||
auto d_bias_h_prime = torch::zeros_like(initial_h);
|
||||
auto d_w1 = torch::empty(0);
|
||||
auto d_w2 = torch::empty(0);
|
||||
auto d_u1 = torch::empty(0);
|
||||
auto d_u2 = torch::empty(0);
|
||||
|
||||
bool w_low_rank = w1.size(0) != 0;
|
||||
bool u_low_rank = u1.size(0) != 0;
|
||||
|
@ -501,20 +505,14 @@ std::vector<torch::Tensor> fastgrnn_unroll_cuda_backward(
|
|||
d_zeta = (d_zeta.sum(0, true)).sum(1, true);
|
||||
d_nu = (d_nu.sum(0, true)).sum(1, true);
|
||||
if (w_low_rank) {
|
||||
auto d_w1 = torch::mm(w2.transpose(0, 1), d_w);
|
||||
auto d_w2 = torch::mm(d_w, w1.transpose(0, 1));
|
||||
d_w1 = torch::mm(w2.transpose(0, 1), d_w);
|
||||
d_w2 = torch::mm(d_w, w1.transpose(0, 1));
|
||||
d_w = torch::empty(0);
|
||||
} else {
|
||||
auto d_w1 = torch::empty(0);
|
||||
auto d_w2 = torch::empty(0);
|
||||
}
|
||||
if(u_low_rank) {
|
||||
auto d_u1 = torch::mm(u2.transpose(0, 1), d_u);
|
||||
auto d_u2 = torch::mm(d_u, u1.transpose(0, 1));
|
||||
d_u1 = torch::mm(u2.transpose(0, 1), d_u);
|
||||
d_u2 = torch::mm(d_u, u1.transpose(0, 1));
|
||||
d_u = torch::empty(0);
|
||||
} else {
|
||||
auto d_u1 = torch::empty(0);
|
||||
auto d_u2 = torch::empty(0);
|
||||
}
|
||||
return {d_input, d_bias_z, d_bias_h_prime, d_zeta, d_nu, d_old_h, d_w, d_u, d_w1, d_w2, d_u1, d_u2};
|
||||
}
|
Загрузка…
Ссылка в новой задаче