зеркало из https://github.com/microsoft/caffe.git
Add and test element-wise abs math functions for CPU and GPU
This commit is contained in:
Родитель
f634899f44
Коммит
ccae3fa587
|
@ -4,7 +4,7 @@
|
|||
#ifndef CAFFE_UTIL_MATH_FUNCTIONS_H_
|
||||
#define CAFFE_UTIL_MATH_FUNCTIONS_H_
|
||||
|
||||
|
||||
#include <cmath> // for std::fabs
|
||||
#include <cublas_v2.h>
|
||||
|
||||
#include "caffe/util/mkl_alternate.hpp"
|
||||
|
@ -136,6 +136,16 @@ void caffe_cpu_sign(const int n, const Dtype* x, Dtype* y) {
|
|||
// GPU counterpart of caffe_cpu_sign. Only the declaration lives here; the
// definitions are supplied elsewhere as per-type specializations.
// NOTE(review): x and y are presumably device pointers -- confirm at callers.
template <class Dtype>
void caffe_gpu_sign(const int n, const Dtype* x, Dtype* y);
|
||||
|
||||
// Element-wise absolute value on the CPU: stores std::fabs(x[i]) into y[i]
// for every i in [0, n). Nothing is read or written when n <= 0.
template <class Dtype>
void caffe_cpu_fabs(const int n, const Dtype* x, Dtype* y) {
  const Dtype* src = x;
  Dtype* dst = y;
  // Walk both arrays in lock-step, front to back, one element per step.
  for (int remaining = n; remaining > 0; --remaining) {
    *dst++ = std::fabs(*src++);
  }
}
|
||||
|
||||
// GPU counterpart of caffe_cpu_fabs: element-wise absolute value of x written
// to y. Only the declaration lives here; float and double specializations are
// defined in the CUDA source.
// NOTE(review): x and y are presumably device pointers -- confirm at callers.
template <class Dtype>
void caffe_gpu_fabs(const int n, const Dtype* x, Dtype* y);
|
||||
|
||||
} // namespace caffe
|
||||
|
||||
|
||||
|
|
|
@ -119,4 +119,25 @@ TYPED_TEST(MathFunctionsTest, TestSignGPU){
|
|||
}
|
||||
}
|
||||
|
||||
// Runs caffe_cpu_fabs over the bottom blob's data, writing into the same
// blob's diff buffer, and verifies each output equals the absolute value of
// the corresponding input.
TYPED_TEST(MathFunctionsTest, TestFabsCPU) {
  const int n = this->blob_bottom_->count();
  const TypeParam* x = this->blob_bottom_->cpu_data();
  caffe_cpu_fabs<TypeParam>(n, x, this->blob_bottom_->mutable_cpu_diff());
  const TypeParam* abs_val = this->blob_bottom_->cpu_diff();
  for (int i = 0; i < n; ++i) {
    // EXPECT_EQ records a mismatch as a gtest failure and keeps going; the
    // glog CHECK_EQ used previously would abort the whole test binary.
    EXPECT_EQ(abs_val[i], x[i] > 0 ? x[i] : -x[i]);
  }
}
|
||||
|
||||
// Runs caffe_gpu_fabs from the bottom blob's GPU data into its GPU diff
// buffer, then reads both back through the CPU accessors and verifies each
// output equals the absolute value of the corresponding input.
TYPED_TEST(MathFunctionsTest, TestFabsGPU) {
  const int n = this->blob_bottom_->count();
  caffe_gpu_fabs<TypeParam>(n, this->blob_bottom_->gpu_data(),
                            this->blob_bottom_->mutable_gpu_diff());
  const TypeParam* abs_val = this->blob_bottom_->cpu_diff();
  const TypeParam* x = this->blob_bottom_->cpu_data();
  for (int i = 0; i < n; ++i) {
    // EXPECT_EQ records a mismatch as a gtest failure and keeps going; the
    // glog CHECK_EQ used previously would abort the whole test binary.
    EXPECT_EQ(abs_val[i], x[i] > 0 ? x[i] : -x[i]);
  }
}
|
||||
|
||||
}
|
||||
|
|
|
@ -416,4 +416,10 @@ void caffe_cpu_sign<float>(const int n, const float* x, float* y);
|
|||
template <>
|
||||
void caffe_cpu_sign<double>(const int n, const double* x, double* y);
|
||||
|
||||
template <>
|
||||
void caffe_cpu_fabs<float>(const int n, const float* x, float* y);
|
||||
|
||||
template <>
|
||||
void caffe_cpu_fabs<double>(const int n, const double* x, double* y);
|
||||
|
||||
} // namespace caffe
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
#include <cmath>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <math_functions.h> // CUDA's, not caffe's, for fabs
|
||||
|
||||
#include "caffe/common.hpp"
|
||||
#include "caffe/util/math_functions.hpp"
|
||||
|
@ -54,4 +55,24 @@ void caffe_gpu_sign<double>(const int n, const double* x, double* y) {
|
|||
n, x, y);
|
||||
}
|
||||
|
||||
template<typename Dtype>
|
||||
__global__ void fabs_kernel(const int n, const Dtype* x, Dtype* y) {
|
||||
int index = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
if (index < n) {
|
||||
y[index] = fabs(x[index]);
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
void caffe_gpu_fabs<float>(const int n, const float* x, float* y) {
|
||||
fabs_kernel<float><<<CAFFE_GET_BLOCKS(n), CAFFE_CUDA_NUM_THREADS>>>(
|
||||
n, x, y);
|
||||
}
|
||||
|
||||
template <>
|
||||
void caffe_gpu_fabs<double>(const int n, const double* x, double* y) {
|
||||
fabs_kernel<double><<<CAFFE_GET_BLOCKS(n), CAFFE_CUDA_NUM_THREADS>>>(
|
||||
n, x, y);
|
||||
}
|
||||
|
||||
} // namespace caffe
|
||||
|
|
Загрузка…
Ссылка в новой задаче