From cb1e26f8ca09ecabfc9ef2da3b8216b8a9fa829c Mon Sep 17 00:00:00 2001 From: Moksh Jain Date: Fri, 27 Sep 2019 14:03:49 +0530 Subject: [PATCH] fastgrnncuda: fix gradient return --- .../cuda/fastgrnn_cuda_kernel.cu | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/pytorch/edgeml_pytorch/cuda/fastgrnn_cuda_kernel.cu b/pytorch/edgeml_pytorch/cuda/fastgrnn_cuda_kernel.cu index 9d72f089..50fc3ce4 100644 --- a/pytorch/edgeml_pytorch/cuda/fastgrnn_cuda_kernel.cu +++ b/pytorch/edgeml_pytorch/cuda/fastgrnn_cuda_kernel.cu @@ -397,6 +397,10 @@ std::vector fastgrnn_unroll_cuda_backward( auto d_nu = torch::zeros_like(initial_h); auto d_bias_z = torch::zeros_like(initial_h); auto d_bias_h_prime = torch::zeros_like(initial_h); + auto d_w1 = torch::empty(0); + auto d_w2 = torch::empty(0); + auto d_u1 = torch::empty(0); + auto d_u2 = torch::empty(0); bool w_low_rank = w1.size(0) != 0; bool u_low_rank = u1.size(0) != 0; @@ -501,20 +505,14 @@ std::vector fastgrnn_unroll_cuda_backward( d_zeta = (d_zeta.sum(0, true)).sum(1, true); d_nu = (d_nu.sum(0, true)).sum(1, true); if (w_low_rank) { - auto d_w1 = torch::mm(w2.transpose(0, 1), d_w); - auto d_w2 = torch::mm(d_w, w1.transpose(0, 1)); + d_w1 = torch::mm(w2.transpose(0, 1), d_w); + d_w2 = torch::mm(d_w, w1.transpose(0, 1)); d_w = torch::empty(0); - } else { - auto d_w1 = torch::empty(0); - auto d_w2 = torch::empty(0); } if(u_low_rank) { - auto d_u1 = torch::mm(u2.transpose(0, 1), d_u); - auto d_u2 = torch::mm(d_u, u1.transpose(0, 1)); + d_u1 = torch::mm(u2.transpose(0, 1), d_u); + d_u2 = torch::mm(d_u, u1.transpose(0, 1)); d_u = torch::empty(0); - } else { - auto d_u1 = torch::empty(0); - auto d_u2 = torch::empty(0); } return {d_input, d_bias_z, d_bias_h_prime, d_zeta, d_nu, d_old_h, d_w, d_u, d_w1, d_w2, d_u1, d_u2}; } \ No newline at end of file