Add and test element wise sign math funtions for CPU and GPU

This commit is contained in:
Kai Li 2014-02-25 19:16:44 +08:00
Родитель 910f3128c7
Коммит 348a338e7f
3 изменённых файлов: 62 добавлений и 4 удалений

Просмотреть файл

@ -119,6 +119,23 @@ Dtype caffe_cpu_asum(const int n, const Dtype* x);
template <typename Dtype>
void caffe_gpu_asum(const int n, const Dtype* x, Dtype* y);
// the branchless, type-safe version from
// http://stackoverflow.com/questions/1903954/is-there-a-standard-sign-function-signum-sgn-in-c-c
template<typename Dtype>
inline char caffe_sign(Dtype val) {
return (Dtype(0) < val) - (val < Dtype(0));
}
template<typename Dtype>
void caffe_cpu_sign(const int n, const Dtype* x, Dtype* y) {
for (int i = 0; i < n; ++i) {
y[i] = caffe_sign<Dtype>(x[i]);
}
}
template<typename Dtype>
void caffe_gpu_sign(const int n, const Dtype* x, Dtype* y);
} // namespace caffe

Просмотреть файл

@ -98,4 +98,25 @@ TYPED_TEST(MathFunctionsTest, TestAsumGPU){
CHECK_LT((gpu_asum - std_asum) / std_asum, 1e-2);
}
} // namespace caffe
TYPED_TEST(MathFunctionsTest, TestSignCPU){
int n = this->blob_bottom_->count();
const TypeParam* x = this->blob_bottom_->cpu_data();
caffe_cpu_sign<TypeParam>(n, x, this->blob_bottom_->mutable_cpu_diff());
const TypeParam* signs = this->blob_bottom_->cpu_diff();
for (int i = 0; i < n; ++i) {
CHECK_EQ(signs[i], x[i] > 0 ? 1 : (x[i] < 0 ? -1 : 0));
}
}
TYPED_TEST(MathFunctionsTest, TestSignGPU){
int n = this->blob_bottom_->count();
caffe_gpu_sign<TypeParam>(n, this->blob_bottom_->gpu_data(),
this->blob_bottom_->mutable_gpu_diff());
const TypeParam* signs = this->blob_bottom_->cpu_diff();
const TypeParam* x = this->blob_bottom_->cpu_data();
for (int i = 0; i < n; ++i) {
CHECK_EQ(signs[i], x[i] > 0 ? 1 : (x[i] < 0 ? -1 : 0));
}
}
}

Просмотреть файл

@ -1,4 +1,5 @@
// Copyright 2013 Yangqing Jia
// Copyright 2014 kloudkl@github
#include <cmath>
#include <cstdlib>
@ -33,5 +34,24 @@ void caffe_gpu_mul<double>(const int N, const double* a,
N, a, b, y);
}
template<typename Dtype>
__global__ void sign_kernel(const int n, const Dtype* x, Dtype* y) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
if (index < n) {
y[index] = (Dtype(0) < x[index]) - (x[index] < Dtype(0));
}
}
template <>
void caffe_gpu_sign<float>(const int n, const float* x, float* y) {
sign_kernel<float><<<CAFFE_GET_BLOCKS(n), CAFFE_CUDA_NUM_THREADS>>>(
n, x, y);
}
template <>
void caffe_gpu_sign<double>(const int n, const double* x, double* y) {
sign_kernel<double><<<CAFFE_GET_BLOCKS(n), CAFFE_CUDA_NUM_THREADS>>>(
n, x, y);
}
} // namespace caffe