зеркало из https://github.com/microsoft/EdgeML.git
FastGRNNCUDA: added unrolled implementation
This commit is contained in:
Родитель
e4ce97f7ce
Коммит
0b70d5e060
|
@ -25,6 +25,30 @@ std::vector<torch::Tensor> fastgrnn_cuda_backward(
|
|||
torch::Tensor z,
|
||||
torch::Tensor h_prime);
|
||||
|
||||
// Forward declarations for the unrolled FastGRNN CUDA implementation
// (defined in the companion .cu translation unit).
std::vector<torch::Tensor> fastgrnn_unroll_cuda_forward(
  torch::Tensor input,
  torch::Tensor w,
  torch::Tensor u,
  torch::Tensor bias_z,
  torch::Tensor bias_h_prime,
  torch::Tensor zeta,
  torch::Tensor nu,
  torch::Tensor initial_h,
  int z_non_linearity);

std::vector<torch::Tensor> fastgrnn_unroll_cuda_backward(
  torch::Tensor grad_h,
  torch::Tensor input,
  torch::Tensor hidden_states,
  torch::Tensor zeta,
  torch::Tensor nu,
  torch::Tensor w,
  torch::Tensor u,
  torch::Tensor z,
  torch::Tensor h_prime,
  torch::Tensor initial_h,
  int z_non_linearity);

// Argument-validation helpers. CHECK_INPUT is wrapped in do { } while (0) so
// that it expands to exactly one statement; the previous two-statement
// expansion (`CHECK_CUDA(x); CHECK_CONTIGUOUS(x)`) silently breaks inside an
// unbraced `if`, attaching only the first check to the condition.
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) do { CHECK_CUDA(x); CHECK_CONTIGUOUS(x); } while (0)
|
||||
|
@ -75,7 +99,68 @@ std::vector<torch::Tensor> fastgrnn_backward(
|
|||
return fastgrnn_cuda_backward(grad_h, input, old_h, zeta, nu, W, U, z_non_linearity, z, h_prime);
|
||||
}
|
||||
|
||||
// Host-side wrapper: validates every tensor argument (CUDA-resident and
// contiguous), then hands off to the CUDA implementation of the unrolled
// FastGRNN forward pass.
std::vector<torch::Tensor> fastgrnn_unroll_forward(
    torch::Tensor input,
    torch::Tensor w,
    torch::Tensor u,
    torch::Tensor bias_z,
    torch::Tensor bias_h_prime,
    torch::Tensor zeta,
    torch::Tensor nu,
    torch::Tensor initial_h,
    int z_non_linearity) {
  CHECK_INPUT(input);
  CHECK_INPUT(w);
  CHECK_INPUT(u);
  CHECK_INPUT(bias_z);
  CHECK_INPUT(bias_h_prime);
  CHECK_INPUT(initial_h);
  CHECK_INPUT(zeta);
  CHECK_INPUT(nu);
  return fastgrnn_unroll_cuda_forward(
      input, w, u, bias_z, bias_h_prime, zeta, nu, initial_h, z_non_linearity);
}
|
||||
|
||||
// Host-side wrapper for the unrolled backward pass: checks that every tensor
// argument lives on the GPU and is contiguous, then forwards all arguments
// unchanged to the CUDA implementation.
std::vector<torch::Tensor> fastgrnn_unroll_backward(
    torch::Tensor grad_h,
    torch::Tensor input,
    torch::Tensor hidden_states,
    torch::Tensor zeta,
    torch::Tensor nu,
    torch::Tensor w,
    torch::Tensor u,
    torch::Tensor z,
    torch::Tensor h_prime,
    torch::Tensor initial_h,
    int z_non_linearity) {
  CHECK_INPUT(grad_h);
  CHECK_INPUT(input);
  CHECK_INPUT(hidden_states);
  CHECK_INPUT(z);
  CHECK_INPUT(h_prime);
  CHECK_INPUT(w);
  CHECK_INPUT(u);
  CHECK_INPUT(zeta);
  CHECK_INPUT(nu);
  CHECK_INPUT(initial_h);

  return fastgrnn_unroll_cuda_backward(
      grad_h, input, hidden_states, zeta, nu, w, u,
      z, h_prime, initial_h, z_non_linearity);
}
|
||||
|
||||
|
||||
// Python bindings: exposes both the single-step and the fully unrolled
// FastGRNN entry points under the extension module name.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &fastgrnn_forward, "FastGRNN forward (CUDA)");
  m.def("backward", &fastgrnn_backward, "FastGRNN backward (CUDA)");
  m.def("forward_unroll", &fastgrnn_unroll_forward, "FastGRNN Unrolled forward (CUDA)");
  m.def("backward_unroll", &fastgrnn_unroll_backward, "FastGRNN Unrolled backward (CUDA)");
}
|
||||
|
|
|
@ -87,6 +87,37 @@ __global__ void fastgrnn_cuda_backward_kernel(
|
|||
d_nu[n][c] = h_prime[n][c] * grad_h[n][c] * d_nu_sigmoid[0][0];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, scalar_t (*d_non_linearity) (scalar_t)>
|
||||
__global__ void fastgrnn_unroll_cuda_backward_kernel(
|
||||
torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_precomp,
|
||||
torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_old_h,
|
||||
torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_bias_z,
|
||||
torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_bias_h_prime,
|
||||
torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_nu,
|
||||
torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_zeta,
|
||||
const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> grad_h,
|
||||
const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> z,
|
||||
const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> h_prime,
|
||||
const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> zeta,
|
||||
const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> nu,
|
||||
const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_zeta_sigmoid,
|
||||
const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_nu_sigmoid,
|
||||
const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> old_h) {
|
||||
const int n = blockIdx.y;
|
||||
const int c = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (c < old_h.size(1)){
|
||||
d_old_h[n][c] = z[n][c] * grad_h[n][c];
|
||||
scalar_t temp_bias_h_prime = (zeta[0][0] * (1.0 - z[n][c]) + nu[0][0]) * d_tanh(h_prime[n][c]) * grad_h[n][c];
|
||||
scalar_t temp_bias_z = (old_h[n][c] - zeta[0][0] * h_prime[n][c]) * d_non_linearity(z[n][c]) * grad_h[n][c];
|
||||
d_bias_h_prime[n][c] += temp_bias_h_prime;
|
||||
d_bias_z[n][c] += temp_bias_z;
|
||||
d_precomp[n][c] = temp_bias_z + temp_bias_h_prime;
|
||||
d_zeta[n][c] += (1.0 - z[n][c]) * h_prime[n][c] * grad_h[n][c] * d_zeta_sigmoid[0][0];
|
||||
d_nu[n][c] += h_prime[n][c] * grad_h[n][c] * d_nu_sigmoid[0][0];
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::vector<torch::Tensor> fastgrnn_cuda_forward(
|
||||
|
@ -246,3 +277,202 @@ std::vector<torch::Tensor> fastgrnn_cuda_backward(
|
|||
|
||||
return {d_input, d_w, d_u, d_bias_z, d_bias_h_prime, d_zeta, d_nu, d_old_h};
|
||||
}
|
||||
|
||||
// Runs the whole FastGRNN sequence on the GPU in one call.
//
// input:      [timesteps, ?, ?] — input[t] is fed to torch::mm(input[t], w^T)
// initial_h:  [batch, state] hidden state at t = -1
// zeta/nu:    raw (pre-sigmoid) scalars, squashed here before the loop
// z_non_linearity: 0 = sigmoid, 1 = relu, 2 = tanh (anything else launches
//                  no kernel and yields zeros — callers pass only 0/1/2)
// Returns {hidden_states, z_s, h_prime_s}, each [timesteps, batch, state],
// saved so the backward pass can avoid recomputation.
//
// Improvement over the previous version: the three AT_DISPATCH launch blocks
// were byte-identical except for the non-linearity template argument; they
// are folded into a single launch macro.
std::vector<torch::Tensor> fastgrnn_unroll_cuda_forward(
  torch::Tensor input,
  torch::Tensor w,
  torch::Tensor u,
  torch::Tensor bias_z,
  torch::Tensor bias_h_prime,
  torch::Tensor zeta,
  torch::Tensor nu,
  torch::Tensor initial_h,
  int z_non_linearity) {
  auto options = torch::TensorOptions().dtype(input.dtype()).device(input.device().type());
  const auto timesteps = input.size(0);
  const auto batch_size = initial_h.size(0);
  const auto state_size = initial_h.size(1);

  auto hidden_states = torch::zeros({timesteps, batch_size, state_size}, options);
  auto z_s = torch::zeros_like(hidden_states);
  auto h_prime_s = torch::zeros_like(hidden_states);

  auto prev_h = initial_h;
  auto new_h = torch::zeros_like(prev_h);
  auto z = torch::zeros_like(prev_h);
  auto h_prime = torch::zeros_like(prev_h);
  auto pre_comp = torch::zeros_like(prev_h);

  // One thread per state element; blockIdx.y indexes the batch row.
  const int threads = 1024;
  const dim3 blocks((state_size + threads - 1) / threads, batch_size);

  // Pre-transpose the weights once; squash zeta/nu once for the whole loop.
  w = w.transpose(0, 1);
  u = u.transpose(0, 1);
  zeta = torch::sigmoid(zeta);
  nu = torch::sigmoid(nu);

// Shared launch site: only the gate non-linearity template argument varies.
#define FASTGRNN_UNROLL_FWD_LAUNCH(GATE_NONLIN)                                      \
  AT_DISPATCH_FLOATING_TYPES(pre_comp.type(), "fastgrnn_forward_cuda", ([&] {        \
    fastgrnn_cuda_forward_kernel<scalar_t, GATE_NONLIN><<<blocks, threads>>>(        \
      new_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),           \
      z.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),               \
      h_prime.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),         \
      pre_comp.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),        \
      bias_z.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),          \
      bias_h_prime.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),    \
      nu.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),              \
      zeta.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),            \
      prev_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());         \
  }))

  for (int t = 0; t < timesteps; t++) {
    // Shared pre-activation for both gates: input[t] . w^T + prev_h . u^T.
    pre_comp = torch::addmm(torch::mm(input[t], w), prev_h, u);

    if (z_non_linearity == 0)
      FASTGRNN_UNROLL_FWD_LAUNCH(sigmoid);
    else if (z_non_linearity == 1)
      FASTGRNN_UNROLL_FWD_LAUNCH(relu);
    else if (z_non_linearity == 2)
      FASTGRNN_UNROLL_FWD_LAUNCH(tanh);

    // Stash per-step outputs for backward, then advance the recurrence.
    hidden_states[t] = new_h;
    z_s[t] = z;
    h_prime_s[t] = h_prime;
    prev_h = new_h;
  }
#undef FASTGRNN_UNROLL_FWD_LAUNCH

  return {hidden_states, z_s, h_prime_s};
}
|
||||
|
||||
// Backward pass over the whole unrolled sequence. Walks the timesteps in
// reverse, launching the per-element kernel each step and carrying d_old_h
// (the gradient flowing into the previous hidden state) across iterations.
//
// BUG FIX: the z_non_linearity == 2 branch previously dispatched d_sigmoid,
// but the forward pass uses tanh for case 2 (0 = sigmoid, 1 = relu,
// 2 = tanh), so its derivative must be d_tanh. Also corrects the dispatch
// label from "fastgrnn_forward_cuda" to "fastgrnn_backward_cuda" and folds
// the three copy-pasted AT_DISPATCH blocks into one launch macro.
//
// Returns {d_input, d_w, d_u, d_bias_z, d_bias_h_prime, d_zeta, d_nu, d_old_h}.
std::vector<torch::Tensor> fastgrnn_unroll_cuda_backward(
  torch::Tensor grad_h,
  torch::Tensor input,
  torch::Tensor hidden_states,
  torch::Tensor zeta,
  torch::Tensor nu,
  torch::Tensor w,
  torch::Tensor u,
  torch::Tensor z,
  torch::Tensor h_prime,
  torch::Tensor initial_h,
  int z_non_linearity) {

  auto d_input = torch::zeros_like(input);
  auto d_w = torch::zeros_like(w);
  auto d_u = torch::zeros_like(u);
  auto d_zeta = torch::zeros_like(initial_h);
  auto d_nu = torch::zeros_like(initial_h);
  auto d_bias_z = torch::zeros_like(initial_h);
  auto d_bias_h_prime = torch::zeros_like(initial_h);

  // Forward received raw zeta/nu; recover the sigmoid-squashed values and
  // their derivatives exactly as fastgrnn_unroll_cuda_forward does.
  zeta = torch::sigmoid(zeta);
  nu = torch::sigmoid(nu);
  auto d_nu_sigmoid = d_sigmoid(nu);
  auto d_zeta_sigmoid = d_sigmoid(zeta);

  auto grad_curr_h = torch::zeros_like(initial_h);
  auto d_precomp = torch::zeros_like(initial_h);
  auto d_old_h = torch::zeros_like(initial_h);
  auto prev_h_ = hidden_states[0];
  auto z_t_ = torch::zeros_like(initial_h);
  auto h_prime_t_ = torch::zeros_like(initial_h);

  const auto batch_size = hidden_states.size(1);
  const auto state_size = hidden_states.size(2);

  const int threads = 1024;
  const dim3 blocks((state_size + threads - 1) / threads, batch_size);

// Shared launch site: D_GATE is the derivative matching the gate
// non-linearity that forward used.
#define FASTGRNN_UNROLL_BWD_LAUNCH(D_GATE)                                            \
  AT_DISPATCH_FLOATING_TYPES(z_t_.type(), "fastgrnn_backward_cuda", ([&] {            \
    fastgrnn_unroll_cuda_backward_kernel<scalar_t, D_GATE><<<blocks, threads>>>(      \
      d_precomp.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),        \
      d_old_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),          \
      d_bias_z.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),         \
      d_bias_h_prime.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),   \
      d_nu.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),             \
      d_zeta.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),           \
      grad_curr_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),      \
      z_t_.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),             \
      h_prime_t_.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),       \
      zeta.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),             \
      nu.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),               \
      d_zeta_sigmoid.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),   \
      d_nu_sigmoid.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),     \
      prev_h_.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());         \
  }))

  for (auto t = hidden_states.size(0) - 1; t >= 0; t--) {
    // Total gradient at h_t: upstream grad for this step plus the carry
    // from step t + 1.
    grad_curr_h = torch::add(grad_h[t], d_old_h);
    z_t_ = z[t];
    h_prime_t_ = h_prime[t];
    prev_h_ = (t == 0) ? initial_h : hidden_states[t - 1];

    if (z_non_linearity == 0)
      FASTGRNN_UNROLL_BWD_LAUNCH(d_sigmoid);
    else if (z_non_linearity == 1)
      FASTGRNN_UNROLL_BWD_LAUNCH(d_relu);
    else if (z_non_linearity == 2)
      FASTGRNN_UNROLL_BWD_LAUNCH(d_tanh);  // was d_sigmoid: mismatch with forward's tanh

    // Propagate through the linear maps; accumulate parameter grads.
    d_old_h = torch::addmm(d_old_h, d_precomp, u);
    d_input[t] = torch::mm(d_precomp, w);
    d_w = torch::addmm(d_w, d_precomp.transpose(0, 1), input[t]);
    d_u = torch::addmm(d_u, d_precomp.transpose(0, 1), prev_h_);
  }
#undef FASTGRNN_UNROLL_BWD_LAUNCH

  // Reduce per-element accumulators: biases over the batch dim, the scalar
  // zeta/nu parameters down to 1x1.
  d_bias_z = d_bias_z.sum(0, true);
  d_bias_h_prime = d_bias_h_prime.sum(0, true);
  d_zeta = (d_zeta.sum(0, true)).sum(1, true);
  d_nu = (d_nu.sum(0, true)).sum(1, true);

  return {d_input, d_w, d_u, d_bias_z, d_bias_h_prime, d_zeta, d_nu, d_old_h};
}
|
||||
|
|
|
@ -1063,6 +1063,34 @@ class FastGRNN(nn.Module):
|
|||
def forward(self, input, hiddenState=None, cellState=None):
    """Run the input sequence through the wrapped RNN.

    Delegates directly to ``self.unrollRNN``; see that callable for the
    expected tensor shapes (defined outside this region — confirm at caller).
    """
    return self.unrollRNN(input, hiddenState, cellState)
|
||||
|
||||
class FastGRNNCUDA(nn.Module):
    """Unrolled FastGRNN layer backed by the custom CUDA extension.

    Processes the entire input sequence in a single call to the fused
    kernels (via FastGRNNUnrollFunction) instead of stepping a cell per
    timestep. Requires a CUDA device.
    """

    def __init__(self, input_size, hidden_size, gate_non_linearity="sigmoid",
                 zetaInit=1.0, nuInit=-4.0, name="FastGRNNCUDACell"):
        super(FastGRNNCUDA, self).__init__()
        if utils.findCUDA() is None:
            raise Exception('FastGRNNCUDA is supported only on GPU devices.')
        # Maps the gate non-linearity name onto the integer code the CUDA
        # kernels dispatch on.
        NON_LINEARITY = {"sigmoid": 0, "relu": 1, "tanh": 2}
        self._input_size = input_size
        self._hidden_size = hidden_size
        self._zetaInit = zetaInit
        self._nuInit = nuInit
        self._name = name
        self._gate_non_linearity = NON_LINEARITY[gate_non_linearity]
        self.W = nn.Parameter(0.1 * torch.randn([input_size, hidden_size]))
        self.U = nn.Parameter(0.1 * torch.randn([hidden_size, hidden_size]))

        self.bias_gate = nn.Parameter(torch.ones([1, hidden_size]))
        self.bias_update = nn.Parameter(torch.ones([1, hidden_size]))
        # zeta/nu are stored raw; the kernels apply sigmoid to them.
        self.zeta = nn.Parameter(self._zetaInit * torch.ones([1, 1]))
        self.nu = nn.Parameter(self._nuInit * torch.ones([1, 1]))

    def forward(self, input, h_state=None, cell_state=None):
        """Run the whole sequence through the fused CUDA kernels.

        input: [timesteps, batch, input_size].
        h_state: optional [batch, hidden_size] initial state; defaults to
            zeros on input's device/dtype (generalization — previously the
            caller had to construct it).
        cell_state: accepted for API symmetry with other cells and ignored.
        """
        if h_state is None:
            h_state = input.new_zeros(input.size(1), self._hidden_size)
        return FastGRNNUnrollFunction.apply(
            input, self.W, self.U, self.bias_gate, self.bias_update,
            self.zeta, self.nu, h_state, self._gate_non_linearity)

    def getVars(self):
        """Return the learnable parameters in the conventional EdgeML order."""
        return [self.W, self.U, self.bias_gate, self.bias_update, self.zeta, self.nu]
|
||||
|
||||
class SRNN2(nn.Module):
|
||||
|
||||
def __init__(self, inputDim, outputDim, hiddenDim0, hiddenDim1, cellType,
|
||||
|
@ -1195,3 +1223,19 @@ class FastGRNNFunction(Function):
|
|||
d_input, d_w, d_u, d_bias_gate, d_bias_update, d_zeta, d_nu, d_old_h = outputs
|
||||
return d_input, d_w, d_u, d_bias_gate, d_bias_update, d_zeta, d_nu, d_old_h, None
|
||||
|
||||
class FastGRNNUnrollFunction(Function):
    """Autograd bridge to the unrolled FastGRNN CUDA kernels."""

    @staticmethod
    def forward(ctx, input, w, u, bias_gate, bias_update, zeta, nu, old_h, gate_non_linearity):
        outputs = fastgrnn_cuda.forward_unroll(input, w, u, bias_gate, bias_update, zeta, nu, old_h, gate_non_linearity)
        hidden_states = outputs[0]
        # outputs[1:] are the saved per-step z and h_prime tensors; keep them
        # (plus the raw inputs) for the backward kernel call.
        variables = [input, hidden_states, zeta, nu, w, u] + outputs[1:] + [old_h]
        ctx.save_for_backward(*variables)
        ctx.gate_non_linearity = gate_non_linearity
        return hidden_states

    @staticmethod
    def backward(ctx, grad_h):
        outputs = fastgrnn_cuda.backward_unroll(
            grad_h.contiguous(), *ctx.saved_variables, ctx.gate_non_linearity)
        d_input, d_w, d_u, d_bias_gate, d_bias_update, d_zeta, d_nu, d_old_h = outputs
        # BUG FIX: backward must return one gradient per forward() input —
        # nine values here. The trailing None covers the non-differentiable
        # gate_non_linearity flag (matching FastGRNNFunction.backward);
        # without it autograd raises a gradient-count mismatch.
        return d_input, d_w, d_u, d_bias_gate, d_bias_update, d_zeta, d_nu, d_old_h, None
|
Загрузка…
Ссылка в новой задаче