FastGRNNCUDA: added unrolled implementation

Moksh Jain 2019-09-26 18:59:04 +05:30
Parent e4ce97f7ce
Commit 0b70d5e060
3 changed files with 359 additions and 0 deletions

View file

@@ -25,6 +25,30 @@ std::vector<torch::Tensor> fastgrnn_cuda_backward(
torch::Tensor z,
torch::Tensor h_prime);
std::vector<torch::Tensor> fastgrnn_unroll_cuda_forward(
torch::Tensor input,
torch::Tensor w,
torch::Tensor u,
torch::Tensor bias_z,
torch::Tensor bias_h_prime,
torch::Tensor zeta,
torch::Tensor nu,
torch::Tensor initial_h,
int z_non_linearity);
std::vector<torch::Tensor> fastgrnn_unroll_cuda_backward(
torch::Tensor grad_h,
torch::Tensor input,
torch::Tensor hidden_states,
torch::Tensor zeta,
torch::Tensor nu,
torch::Tensor w,
torch::Tensor u,
torch::Tensor z,
torch::Tensor h_prime,
torch::Tensor initial_h,
int z_non_linearity);
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
@@ -75,7 +99,68 @@ std::vector<torch::Tensor> fastgrnn_backward(
return fastgrnn_cuda_backward(grad_h, input, old_h, zeta, nu, W, U, z_non_linearity, z, h_prime);
}
std::vector<torch::Tensor> fastgrnn_unroll_forward(
torch::Tensor input,
torch::Tensor w,
torch::Tensor u,
torch::Tensor bias_z,
torch::Tensor bias_h_prime,
torch::Tensor zeta,
torch::Tensor nu,
torch::Tensor initial_h,
int z_non_linearity) {
CHECK_INPUT(input);
CHECK_INPUT(w);
CHECK_INPUT(u);
CHECK_INPUT(bias_z);
CHECK_INPUT(bias_h_prime);
CHECK_INPUT(initial_h);
CHECK_INPUT(zeta);
CHECK_INPUT(nu);
return fastgrnn_unroll_cuda_forward(input, w, u, bias_z, bias_h_prime, zeta, nu, initial_h, z_non_linearity);
}
std::vector<torch::Tensor> fastgrnn_unroll_backward(
torch::Tensor grad_h,
torch::Tensor input,
torch::Tensor hidden_states,
torch::Tensor zeta,
torch::Tensor nu,
torch::Tensor w,
torch::Tensor u,
torch::Tensor z,
torch::Tensor h_prime,
torch::Tensor initial_h,
int z_non_linearity) {
CHECK_INPUT(grad_h);
CHECK_INPUT(input);
CHECK_INPUT(hidden_states);
CHECK_INPUT(z);
CHECK_INPUT(h_prime);
CHECK_INPUT(w);
CHECK_INPUT(u);
CHECK_INPUT(zeta);
CHECK_INPUT(nu);
CHECK_INPUT(initial_h);
return fastgrnn_unroll_cuda_backward(
grad_h,
input,
hidden_states,
zeta,
nu,
w,
u,
z,
h_prime,
initial_h,
z_non_linearity);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &fastgrnn_forward, "FastGRNN forward (CUDA)");
m.def("backward", &fastgrnn_backward, "FastGRNN backward (CUDA)");
m.def("forward_unroll", &fastgrnn_unroll_forward, "FastGRNN Unrolled forward (CUDA)");
m.def("backward_unroll", &fastgrnn_unroll_backward, "FastGRNN Unrolled backward (CUDA)");
}
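For reference, a pybind11 module like this is typically JIT-compiled and loaded from Python with torch.utils.cpp_extension.load. A minimal sketch, assuming the C++ and CUDA sources are named fastgrnn_cuda.cpp and fastgrnn_cuda_kernel.cu (the actual file names are not visible in this diff):

# Hypothetical build/load sketch; the source file names are assumptions.
from torch.utils.cpp_extension import load

fastgrnn_cuda = load(
    name="fastgrnn_cuda",
    sources=["fastgrnn_cuda.cpp", "fastgrnn_cuda_kernel.cu"],
    verbose=True,
)
# The bindings registered above then appear as regular Python callables:
# fastgrnn_cuda.forward_unroll(...), fastgrnn_cuda.backward_unroll(...)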

View file

@@ -87,6 +87,37 @@ __global__ void fastgrnn_cuda_backward_kernel(
d_nu[n][c] = h_prime[n][c] * grad_h[n][c] * d_nu_sigmoid[0][0];
}
}
template <typename scalar_t, scalar_t (*d_non_linearity) (scalar_t)>
__global__ void fastgrnn_unroll_cuda_backward_kernel(
torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_precomp,
torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_old_h,
torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_bias_z,
torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_bias_h_prime,
torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_nu,
torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_zeta,
const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> grad_h,
const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> z,
const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> h_prime,
const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> zeta,
const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> nu,
const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_zeta_sigmoid,
const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_nu_sigmoid,
const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> old_h) {
const int n = blockIdx.y;
const int c = blockIdx.x * blockDim.x + threadIdx.x;
if (c < old_h.size(1)){
d_old_h[n][c] = z[n][c] * grad_h[n][c];
scalar_t temp_bias_h_prime = (zeta[0][0] * (1.0 - z[n][c]) + nu[0][0]) * d_tanh(h_prime[n][c]) * grad_h[n][c];
scalar_t temp_bias_z = (old_h[n][c] - zeta[0][0] * h_prime[n][c]) * d_non_linearity(z[n][c]) * grad_h[n][c];
d_bias_h_prime[n][c] += temp_bias_h_prime;
d_bias_z[n][c] += temp_bias_z;
d_precomp[n][c] = temp_bias_z + temp_bias_h_prime;
d_zeta[n][c] += (1.0 - z[n][c]) * h_prime[n][c] * grad_h[n][c] * d_zeta_sigmoid[0][0];
d_nu[n][c] += h_prime[n][c] * grad_h[n][c] * d_nu_sigmoid[0][0];
}
}
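In scalar form, each thread handles one (batch, channel) element; the d_bias_* and d_zeta/d_nu outputs are accumulated with += across timesteps, while d_precomp and d_old_h are rewritten each step. A pure-PyTorch sketch of a single timestep of this computation (zeta and nu are the already-sigmoided scalars, matching the host code further down; function and argument names are illustrative):

import torch

def d_sigmoid(y):
    # derivative of sigmoid written in terms of its output y = sigmoid(x)
    return y * (1.0 - y)

def d_tanh(y):
    # derivative of tanh written in terms of its output y = tanh(x)
    return 1.0 - y * y

def unroll_backward_step(grad_h, z, h_prime, old_h, zeta, nu, d_gate=d_sigmoid):
    # One timestep of the element-wise gradients the kernel computes.
    d_old_h = z * grad_h
    d_bias_h_prime = (zeta * (1.0 - z) + nu) * d_tanh(h_prime) * grad_h
    d_bias_z = (old_h - zeta * h_prime) * d_gate(z) * grad_h
    d_precomp = d_bias_z + d_bias_h_prime
    d_zeta = (1.0 - z) * h_prime * grad_h * d_sigmoid(zeta)
    d_nu = h_prime * grad_h * d_sigmoid(nu)
    return d_precomp, d_old_h, d_bias_z, d_bias_h_prime, d_zeta, d_nu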
} // namespace
std::vector<torch::Tensor> fastgrnn_cuda_forward(
@@ -246,3 +277,202 @@ std::vector<torch::Tensor> fastgrnn_cuda_backward(
return {d_input, d_w, d_u, d_bias_z, d_bias_h_prime, d_zeta, d_nu, d_old_h};
}
std::vector<torch::Tensor> fastgrnn_unroll_cuda_forward(
torch::Tensor input,
torch::Tensor w,
torch::Tensor u,
torch::Tensor bias_z,
torch::Tensor bias_h_prime,
torch::Tensor zeta,
torch::Tensor nu,
torch::Tensor initial_h,
int z_non_linearity) {
auto options = torch::TensorOptions().dtype(input.dtype()).device(input.device());
const auto timesteps = input.size(0);
const auto batch_size = initial_h.size(0);
const auto state_size = initial_h.size(1);
auto hidden_states = torch::zeros({timesteps, batch_size, state_size}, options);
auto z_s = torch::zeros_like(hidden_states);
auto h_prime_s = torch::zeros_like(hidden_states);
auto prev_h = initial_h;
auto new_h = torch::zeros_like(prev_h);
auto z = torch::zeros_like(prev_h);
auto h_prime = torch::zeros_like(prev_h);
auto pre_comp = torch::zeros_like(prev_h);
const int threads = 1024;
const dim3 blocks((state_size + threads - 1) / threads, batch_size);
w = w.transpose(0, 1);
u = u.transpose(0, 1);
zeta = torch::sigmoid(zeta);
nu = torch::sigmoid(nu);
for (int t=0; t < timesteps; t++) {
pre_comp = torch::addmm(torch::mm(input[t], w), prev_h, u);
if (z_non_linearity == 0)
AT_DISPATCH_FLOATING_TYPES(pre_comp.type(), "fastgrnn_forward_cuda", ([&] {
fastgrnn_cuda_forward_kernel<scalar_t, sigmoid><<<blocks, threads>>>(
new_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
z.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
h_prime.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
pre_comp.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
bias_z.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
bias_h_prime.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
nu.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
zeta.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
prev_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());
}));
else if(z_non_linearity == 1)
AT_DISPATCH_FLOATING_TYPES(pre_comp.type(), "fastgrnn_forward_cuda", ([&] {
fastgrnn_cuda_forward_kernel<scalar_t, relu><<<blocks, threads>>>(
new_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
z.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
h_prime.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
pre_comp.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
bias_z.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
bias_h_prime.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
nu.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
zeta.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
prev_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());
}));
else if (z_non_linearity == 2)
AT_DISPATCH_FLOATING_TYPES(pre_comp.type(), "fastgrnn_forward_cuda", ([&] {
fastgrnn_cuda_forward_kernel<scalar_t, tanh><<<blocks, threads>>>(
new_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
z.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
h_prime.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
pre_comp.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
bias_z.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
bias_h_prime.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
nu.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
zeta.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
prev_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());
}));
hidden_states[t] = new_h;
z_s[t] = z;
h_prime_s[t] = h_prime;
prev_h = new_h;
}
return {hidden_states, z_s, h_prime_s};
}
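The loop above is the standard FastGRNN recurrence h_t = (sigmoid(zeta) * (1 - z_t) + sigmoid(nu)) * h_tilde_t + z_t * h_{t-1}, with z_t and h_tilde_t sharing the precomputed Wx_t + Uh_{t-1}. A pure-PyTorch sketch of the same computation, handy as a CPU reference when testing the kernel (names are illustrative; w is [hidden_size, input_size] and u is [hidden_size, hidden_size], matching the transposes above):

import torch

def fastgrnn_unroll_forward_reference(x, w, u, bias_z, bias_h_prime,
                                      zeta, nu, h0, gate=torch.sigmoid):
    # x: [timesteps, batch, input_size], h0: [batch, hidden_size];
    # zeta and nu are the raw scalar parameters, sigmoided here as in the host code.
    zeta, nu = torch.sigmoid(zeta), torch.sigmoid(nu)
    h = h0
    hidden_states, z_s, h_prime_s = [], [], []
    for x_t in x:
        pre_comp = x_t @ w.t() + h @ u.t()
        z = gate(pre_comp + bias_z)
        h_prime = torch.tanh(pre_comp + bias_h_prime)
        h = (zeta * (1.0 - z) + nu) * h_prime + z * h
        hidden_states.append(h); z_s.append(z); h_prime_s.append(h_prime)
    return torch.stack(hidden_states), torch.stack(z_s), torch.stack(h_prime_s)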
std::vector<torch::Tensor> fastgrnn_unroll_cuda_backward(
torch::Tensor grad_h,
torch::Tensor input,
torch::Tensor hidden_states,
torch::Tensor zeta,
torch::Tensor nu,
torch::Tensor w,
torch::Tensor u,
torch::Tensor z,
torch::Tensor h_prime,
torch::Tensor initial_h,
int z_non_linearity) {
auto d_input = torch::zeros_like(input);
auto d_w = torch::zeros_like(w);
auto d_u = torch::zeros_like(u);
auto d_zeta = torch::zeros_like(initial_h);
auto d_nu = torch::zeros_like(initial_h);
auto d_bias_z = torch::zeros_like(initial_h);
auto d_bias_h_prime = torch::zeros_like(initial_h);
zeta = torch::sigmoid(zeta);
nu = torch::sigmoid(nu);
auto d_nu_sigmoid = d_sigmoid(nu);
auto d_zeta_sigmoid = d_sigmoid(zeta);
auto grad_curr_h = torch::zeros_like(initial_h);
auto d_precomp = torch::zeros_like(initial_h);
auto d_old_h = torch::zeros_like(initial_h);
auto prev_h_ = hidden_states[0];
auto z_t_ = torch::zeros_like(initial_h);
auto h_prime_t_ = torch::zeros_like(initial_h);
const auto batch_size = hidden_states.size(1);
const auto state_size = hidden_states.size(2);
const int threads = 1024;
const dim3 blocks((state_size + threads - 1) / threads, batch_size);
for (auto t = hidden_states.size(0) - 1; t>=0; t--) {
grad_curr_h = torch::add(grad_h[t], d_old_h);
z_t_ = z[t];
h_prime_t_ = h_prime[t];
if (t == 0)
prev_h_ = initial_h;
else
prev_h_ = hidden_states[t-1];
if (z_non_linearity == 0)
AT_DISPATCH_FLOATING_TYPES(z_t_.type(), "fastgrnn_backward_cuda", ([&] {
fastgrnn_unroll_cuda_backward_kernel<scalar_t, d_sigmoid><<<blocks, threads>>>(
d_precomp.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
d_old_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
d_bias_z.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
d_bias_h_prime.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
d_nu.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
d_zeta.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
grad_curr_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
z_t_.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
h_prime_t_.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
zeta.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
nu.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
d_zeta_sigmoid.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
d_nu_sigmoid.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
prev_h_.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());
}));
else if (z_non_linearity == 1)
AT_DISPATCH_FLOATING_TYPES(z_t_.type(), "fastgrnn_backward_cuda", ([&] {
fastgrnn_unroll_cuda_backward_kernel<scalar_t, d_relu><<<blocks, threads>>>(
d_precomp.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
d_old_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
d_bias_z.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
d_bias_h_prime.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
d_nu.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
d_zeta.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
grad_curr_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
z_t_.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
h_prime_t_.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
zeta.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
nu.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
d_zeta_sigmoid.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
d_nu_sigmoid.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
prev_h_.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());
}));
else if(z_non_linearity == 2)
AT_DISPATCH_FLOATING_TYPES(z_t_.type(), "fastgrnn_backward_cuda", ([&] {
fastgrnn_unroll_cuda_backward_kernel<scalar_t, d_tanh><<<blocks, threads>>>(
d_precomp.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
d_old_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
d_bias_z.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
d_bias_h_prime.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
d_nu.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
d_zeta.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
grad_curr_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
z_t_.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
h_prime_t_.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
zeta.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
nu.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
d_zeta_sigmoid.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
d_nu_sigmoid.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
prev_h_.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());
}));
d_old_h = torch::addmm(d_old_h, d_precomp, u);
d_input[t] = torch::mm(d_precomp, w);
d_w = torch::addmm(d_w, d_precomp.transpose(0, 1), input[t]);
d_u = torch::addmm(d_u, d_precomp.transpose(0, 1), prev_h_);
// grad_curr_h = d_old_h;
}
d_bias_z = d_bias_z.sum(0, true);
d_bias_h_prime = d_bias_h_prime.sum(0, true);
d_zeta = (d_zeta.sum(0, true)).sum(1, true);
d_nu = (d_nu.sum(0, true)).sum(1, true);
return {d_input, d_w, d_u, d_bias_z, d_bias_h_prime, d_zeta, d_nu, d_old_h};
}
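One way to sanity-check this hand-rolled BPTT is to push the same parameters through the pure-PyTorch reference above and let autograd produce the gradients to compare against the tensors returned here. A sketch with illustrative sizes:

import torch

T, B, D, H = 4, 2, 8, 16
x = torch.randn(T, B, D, device="cuda")
w = torch.randn(H, D, device="cuda", requires_grad=True)
u = torch.randn(H, H, device="cuda", requires_grad=True)
bias_z = torch.zeros(1, H, device="cuda", requires_grad=True)
bias_h_prime = torch.zeros(1, H, device="cuda", requires_grad=True)
zeta = torch.ones(1, 1, device="cuda", requires_grad=True)
nu = torch.full((1, 1), -4.0, device="cuda", requires_grad=True)
h0 = torch.zeros(B, H, device="cuda")

hs, _, _ = fastgrnn_unroll_forward_reference(x, w, u, bias_z, bias_h_prime,
                                             zeta, nu, h0)
hs.sum().backward()  # w.grad, u.grad, ... should match d_w, d_u, ... above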

View file

@@ -1063,6 +1063,34 @@ class FastGRNN(nn.Module):
def forward(self, input, hiddenState=None, cellState=None):
return self.unrollRNN(input, hiddenState, cellState)
class FastGRNNCUDA(nn.Module):
"""Unrolled implementation of the FastGRNNCUDACell"""
def __init__(self, input_size, hidden_size, gate_non_linearity="sigmoid", zetaInit=1.0, nuInit=-4.0, name="FastGRNNCUDACell"):
super(FastGRNNCUDA, self).__init__()
if utils.findCUDA() is None:
raise Exception('FastGRNNCUDA is supported only on GPU devices.')
NON_LINEARITY = {"sigmoid": 0, "relu": 1, "tanh": 2}
self._input_size = input_size
self._hidden_size = hidden_size
self._zetaInit = zetaInit
self._nuInit = nuInit
self._name = name
self._gate_non_linearity = NON_LINEARITY[gate_non_linearity]
self.W = nn.Parameter(0.1 * torch.randn([hidden_size, input_size]))
self.U = nn.Parameter(0.1 * torch.randn([hidden_size, hidden_size]))
self.bias_gate = nn.Parameter(torch.ones([1, hidden_size]))
self.bias_update = nn.Parameter(torch.ones([1, hidden_size]))
self.zeta = nn.Parameter(self._zetaInit * torch.ones([1, 1]))
self.nu = nn.Parameter(self._nuInit * torch.ones([1, 1]))
def forward(self, input, h_state, cell_state=None):
# input: [timesteps, batch, input_size]
return FastGRNNUnrollFunction.apply(input, self.W, self.U, self.bias_gate, self.bias_update, self.zeta, self.nu, h_state, self._gate_non_linearity)
def getVars(self):
return [self.W, self.U, self.bias_gate, self.bias_update, self.zeta, self.nu]
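Illustrative usage of the module (sizes are arbitrary; parameters and inputs must live on the GPU, since the bindings reject non-CUDA tensors):

import torch

model = FastGRNNCUDA(input_size=32, hidden_size=64,
                     gate_non_linearity="sigmoid").cuda()
x = torch.randn(50, 16, 32, device="cuda")   # [timesteps, batch, input_size]
h0 = torch.zeros(16, 64, device="cuda")      # [batch, hidden_size]
hidden_states = model(x, h0)                 # [timesteps, batch, hidden_size]
hidden_states[-1].pow(2).sum().backward()    # dispatches to backward_unroll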
class SRNN2(nn.Module):
def __init__(self, inputDim, outputDim, hiddenDim0, hiddenDim1, cellType,
@@ -1195,3 +1223,19 @@ class FastGRNNFunction(Function):
d_input, d_w, d_u, d_bias_gate, d_bias_update, d_zeta, d_nu, d_old_h = outputs
return d_input, d_w, d_u, d_bias_gate, d_bias_update, d_zeta, d_nu, d_old_h, None
class FastGRNNUnrollFunction(Function):
@staticmethod
def forward(ctx, input, w, u, bias_gate, bias_update, zeta, nu, old_h, gate_non_linearity):
outputs = fastgrnn_cuda.forward_unroll(input, w, u, bias_gate, bias_update, zeta, nu, old_h, gate_non_linearity)
hidden_states = outputs[0]
variables = [input, hidden_states, zeta, nu, w, u] + outputs[1:] + [old_h]
ctx.save_for_backward(*variables)
ctx.gate_non_linearity = gate_non_linearity
return hidden_states
@staticmethod
def backward(ctx, grad_h):
outputs = fastgrnn_cuda.backward_unroll(
grad_h.contiguous(), *ctx.saved_variables, ctx.gate_non_linearity)
d_input, d_w, d_u, d_bias_gate, d_bias_update, d_zeta, d_nu, d_old_h = outputs
return d_input, d_w, d_u, d_bias_gate, d_bias_update, d_zeta, d_nu, d_old_h, None
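Since the extension is compiled for all floating types via AT_DISPATCH_FLOATING_TYPES, the custom Function can also be checked numerically with torch.autograd.gradcheck in double precision. A sketch with illustrative sizes; the integer gate selector is passed through as a non-tensor argument:

import torch

T, B, D, H = 3, 2, 4, 5
kw = dict(dtype=torch.float64, device="cuda", requires_grad=True)
inputs = (
    torch.randn(T, B, D, **kw),   # input
    torch.randn(H, D, **kw),      # w
    torch.randn(H, H, **kw),      # u
    torch.randn(1, H, **kw),      # bias_gate
    torch.randn(1, H, **kw),      # bias_update
    torch.randn(1, 1, **kw),      # zeta
    torch.randn(1, 1, **kw),      # nu
    torch.randn(B, H, dtype=torch.float64, device="cuda"),  # initial h
    0,                            # gate non-linearity: sigmoid
)
torch.autograd.gradcheck(FastGRNNUnrollFunction.apply, inputs)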