Support tensor rank either 1 or 2 in CUDA vrelu3 (#815)

This commit is contained in:
David Collier 2021-05-24 15:24:33 +01:00 committed by GitHub
Parent 7392c718ae
Commit f20f21813a
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
1 changed file with 109 additions and 44 deletions
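Summary (inferred from the diff below): vrelu3_cuda_forward and vrelu3_cuda_backward previously assumed rank-2 inputs. This change factors the elementwise math into relu3_forward / relu3_backward device helpers, adds dedicated rank-1 kernels alongside the rank-2 ones, and selects between them by switching on the tensor's rank; any other rank now fails with an explicit TORCH_CHECK.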


@@ -5,84 +5,149 @@
#include <vector>
template <typename scalar_t> using tensor_accessor_1 =
    torch::PackedTensorAccessor32<scalar_t,1,torch::RestrictPtrTraits>;
template <typename scalar_t> using tensor_accessor_2 =
    torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits>;
template <typename scalar_t>
inline __device__ scalar_t relu3_forward(scalar_t input) {
  // Cubic ReLU: zero below 0, x^3/3 on [0, 1), linear (slope 1) above 1.
  if (input < (scalar_t)0.0) {
    return (scalar_t)0.0;
  } else if (input < (scalar_t)1.0) {
    return (scalar_t)1/3 * input * input * input;
  } else {
    return input - (scalar_t)2/3;
  }
}
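For reference, the activation these kernels implement (read off relu3_forward above and relu3_backward further down) is, in math form:

\[
\mathrm{relu3}(x) =
\begin{cases}
0 & x < 0 \\
x^3/3 & 0 \le x < 1 \\
x - 2/3 & x \ge 1
\end{cases}
\qquad
\mathrm{relu3}'(x) =
\begin{cases}
0 & x < 0 \\
x^2 & 0 \le x < 1 \\
1 & x \ge 1
\end{cases}
\]

The pieces agree at both joints (at x = 1, x^3/3 = x - 2/3 = 1/3), so relu3 is continuously differentiable; the x^2 branch is exactly the factor relu3_backward applies to the incoming gradient.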
template <typename scalar_t>
__global__ void vrelu3_cuda_forward_kernel_1(
    const tensor_accessor_1<scalar_t> input,
    tensor_accessor_1<scalar_t> output) {
  // One thread per element; guard against the final partial block.
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < input.size(0))
    output[i] = relu3_forward(input[i]);
}
template <typename scalar_t>
__global__ void vrelu3_cuda_forward_kernel_2(
    const tensor_accessor_2<scalar_t> input,
    tensor_accessor_2<scalar_t> output) {
  // Row index comes from the grid's y dimension, element index from x.
  const int n = blockIdx.y;
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < input.size(1))
    output[n][i] = relu3_forward(input[n][i]);
}
torch::Tensor vrelu3_cuda_forward(torch::Tensor input) {
  auto output = torch::zeros_like(input);
  switch (input.sizes().size()) {
    case 1: {
      const auto input_size = input.size(0);
      // TODO: find out how PyTorch chooses these parameters
      const int threads = 1024;
      const int blocks = (input_size + threads - 1) / threads;
      AT_DISPATCH_FLOATING_TYPES(input.type(), "vrelu3_forward_cuda (rank 1)", ([&] {
        vrelu3_cuda_forward_kernel_1<scalar_t><<<blocks, threads>>>(
            input.packed_accessor32<scalar_t,1,torch::RestrictPtrTraits>(),
            output.packed_accessor32<scalar_t,1,torch::RestrictPtrTraits>());
      }));
      break;
    }
    case 2: {
      const auto input_size_0 = input.size(0);
      const auto input_size_1 = input.size(1);
      const int threads = 1024;
      // One grid row per tensor row; the x dimension covers the columns.
      const dim3 blocks((input_size_1 + threads - 1) / threads, input_size_0);
      AT_DISPATCH_FLOATING_TYPES(input.type(), "vrelu3_forward_cuda (rank 2)", ([&] {
        vrelu3_cuda_forward_kernel_2<scalar_t><<<blocks, threads>>>(
            input.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
            output.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>());
      }));
      break;
    }
    default:
      TORCH_CHECK(false, "Unsupported tensor rank");
  }
  return output;
}
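As a quick illustration of the new dispatch, here is a hypothetical host-side sketch (not part of this commit; assumes libtorch headers and a CUDA device are available):

// Hypothetical usage sketch (not in this commit): exercises both supported
// ranks and the failure path.
#include <torch/torch.h>

torch::Tensor vrelu3_cuda_forward(torch::Tensor input);  // defined above

void vrelu3_demo() {
  auto opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
  auto x1 = torch::linspace(-2.0, 2.0, 5, opts);  // rank 1: newly supported
  auto x2 = x1.reshape({1, 5}).repeat({3, 1});    // rank 2: as before
  auto y1 = vrelu3_cuda_forward(x1);  // [0, 0, 0, 1/3, 4/3]
  auto y2 = vrelu3_cuda_forward(x2);  // same values in each row
  // Rank 3 now fails with an explicit error:
  // vrelu3_cuda_forward(x2.unsqueeze(0));  // TORCH_CHECK: "Unsupported tensor rank"
}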
template <typename scalar_t>
inline __device__ scalar_t relu3_backward(scalar_t grad, scalar_t x) {
  // Derivative of relu3: 0 below 0, x^2 on [0, 1), 1 above.
  if (x < (scalar_t)0.0) {
    return (scalar_t)0.0;
  } else if (x < (scalar_t)1.0) {
    return x * x * grad;
  } else {
    return grad;
  }
}
template <typename scalar_t>
__global__ void vrelu3_cuda_backward_kernel_1(
    tensor_accessor_1<scalar_t> d_x,
    const tensor_accessor_1<scalar_t> grad,
    const tensor_accessor_1<scalar_t> x) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < x.size(0))
    d_x[i] = relu3_backward(grad[i], x[i]);
}
template <typename scalar_t>
__global__ void vrelu3_cuda_backward_kernel_2(
    tensor_accessor_2<scalar_t> d_x,
    const tensor_accessor_2<scalar_t> grad,
    const tensor_accessor_2<scalar_t> x) {
  const int n = blockIdx.y;
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < x.size(1))
    d_x[n][i] = relu3_backward(grad[n][i], x[n][i]);
}
torch::Tensor vrelu3_cuda_backward(
    torch::Tensor grad,
    torch::Tensor x) {
  auto d_x = torch::zeros_like(x);
  switch (x.sizes().size()) {
    case 1: {
      auto x_size = x.size(0);
      // TODO: find out how PyTorch chooses these parameters
      const int threads = 1024;
      const int blocks = (x_size + threads - 1) / threads;
      AT_DISPATCH_FLOATING_TYPES(x.type(), "vrelu3_backward_cuda (rank 1)", ([&] {
        vrelu3_cuda_backward_kernel_1<scalar_t><<<blocks, threads>>>(
            d_x.packed_accessor32<scalar_t,1,torch::RestrictPtrTraits>(),
            grad.packed_accessor32<scalar_t,1,torch::RestrictPtrTraits>(),
            x.packed_accessor32<scalar_t,1,torch::RestrictPtrTraits>());
      }));
      break;
    }
    case 2: {
      auto x_size_0 = x.size(0);
      auto x_size_1 = x.size(1);
      const int threads = 1024;
      const dim3 blocks((x_size_1 + threads - 1) / threads, x_size_0);
      AT_DISPATCH_FLOATING_TYPES(x.type(), "vrelu3_backward_cuda (rank 2)", ([&] {
        vrelu3_cuda_backward_kernel_2<scalar_t><<<blocks, threads>>>(
            d_x.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
            grad.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
            x.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>());
      }));
      break;
    }
    default:
      TORCH_CHECK(false, "Unsupported tensor rank");
  }
  return d_x;
}
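A simple way to sanity-check the backward path against the forward is a central finite difference. The sketch below is hypothetical (not part of this commit) and assumes the same libtorch setup as above; float64 keeps the numerical error small.

// Hypothetical finite-difference check of vrelu3_cuda_backward. Inputs that
// land within eps of the kinks at x = 0 and x = 1 can legitimately disagree
// and may need excluding in a real test.
#include <torch/torch.h>

torch::Tensor vrelu3_cuda_forward(torch::Tensor input);
torch::Tensor vrelu3_cuda_backward(torch::Tensor grad, torch::Tensor x);

bool vrelu3_grad_ok() {
  auto opts = torch::TensorOptions().dtype(torch::kFloat64).device(torch::kCUDA);
  auto x = torch::randn({4, 7}, opts);
  const double eps = 1e-6;
  // Central difference: (f(x + eps) - f(x - eps)) / (2 * eps)
  auto numeric =
      (vrelu3_cuda_forward(x + eps) - vrelu3_cuda_forward(x - eps)) / (2 * eps);
  auto analytic = vrelu3_cuda_backward(torch::ones_like(x), x);
  return torch::allclose(numeric, analytic, /*rtol=*/1e-4, /*atol=*/1e-6);
}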