Add and test element wise abs math functions for CPU and GPU

This commit is contained in:
Kai Li 2014-02-25 19:33:10 +08:00
Родитель f634899f44
Коммит ccae3fa587
4 изменённых файлов: 59 добавлений и 1 удалений

Просмотреть файл

@ -4,7 +4,7 @@
#ifndef CAFFE_UTIL_MATH_FUNCTIONS_H_
#define CAFFE_UTIL_MATH_FUNCTIONS_H_
#include <cmath> // for std::fabs
#include <cublas_v2.h>
#include "caffe/util/mkl_alternate.hpp"
@ -136,6 +136,16 @@ void caffe_cpu_sign(const int n, const Dtype* x, Dtype* y) {
template<typename Dtype>
void caffe_gpu_sign(const int n, const Dtype* x, Dtype* y);
template <typename Dtype>
void caffe_cpu_fabs(const int n, const Dtype* x, Dtype* y) {
for (int i = 0; i < n; ++i) {
y[i] = std::fabs(x[i]);
}
}
template <typename Dtype>
void caffe_gpu_fabs(const int n, const Dtype* x, Dtype* y);
} // namespace caffe

Просмотреть файл

@ -119,4 +119,25 @@ TYPED_TEST(MathFunctionsTest, TestSignGPU){
}
}
TYPED_TEST(MathFunctionsTest, TestFabsCPU){
int n = this->blob_bottom_->count();
const TypeParam* x = this->blob_bottom_->cpu_data();
caffe_cpu_fabs<TypeParam>(n, x, this->blob_bottom_->mutable_cpu_diff());
const TypeParam* abs_val = this->blob_bottom_->cpu_diff();
for (int i = 0; i < n; ++i) {
CHECK_EQ(abs_val[i], x[i] > 0 ? x[i] : -x[i]);
}
}
TYPED_TEST(MathFunctionsTest, TestFabsGPU){
int n = this->blob_bottom_->count();
caffe_gpu_fabs<TypeParam>(n, this->blob_bottom_->gpu_data(),
this->blob_bottom_->mutable_gpu_diff());
const TypeParam* abs_val = this->blob_bottom_->cpu_diff();
const TypeParam* x = this->blob_bottom_->cpu_data();
for (int i = 0; i < n; ++i) {
CHECK_EQ(abs_val[i], x[i] > 0 ? x[i] : -x[i]);
}
}
}

Просмотреть файл

@ -416,4 +416,10 @@ void caffe_cpu_sign<float>(const int n, const float* x, float* y);
template <>
void caffe_cpu_sign<double>(const int n, const double* x, double* y);
template <>
void caffe_cpu_fabs<float>(const int n, const float* x, float* y);
template <>
void caffe_cpu_fabs<double>(const int n, const double* x, double* y);
} // namespace caffe

Просмотреть файл

@ -4,6 +4,7 @@
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <math_functions.h> // CUDA's, not caffe's, for fabs
#include "caffe/common.hpp"
#include "caffe/util/math_functions.hpp"
@ -54,4 +55,24 @@ void caffe_gpu_sign<double>(const int n, const double* x, double* y) {
n, x, y);
}
template<typename Dtype>
__global__ void fabs_kernel(const int n, const Dtype* x, Dtype* y) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
if (index < n) {
y[index] = fabs(x[index]);
}
}
template <>
void caffe_gpu_fabs<float>(const int n, const float* x, float* y) {
fabs_kernel<float><<<CAFFE_GET_BLOCKS(n), CAFFE_CUDA_NUM_THREADS>>>(
n, x, y);
}
template <>
void caffe_gpu_fabs<double>(const int n, const double* x, double* y) {
fabs_kernel<double><<<CAFFE_GET_BLOCKS(n), CAFFE_CUDA_NUM_THREADS>>>(
n, x, y);
}
} // namespace caffe