Merge pull request #201 from kloudkl/more_math_functions

Add more convenience math functions and all tests pass
This commit is contained in:
Evan Shelhamer 2014-03-24 16:13:52 -07:00
Родитель d3e4c21d91 4d53804846
Коммит 91483aef03
4 изменённых файлов: 262 добавлений и 3 удалений

Просмотреть файл

@ -4,8 +4,9 @@
#ifndef CAFFE_UTIL_MATH_FUNCTIONS_H_
#define CAFFE_UTIL_MATH_FUNCTIONS_H_
#include <cublas_v2.h>
#include <math.h> // for signbit
#include <cmath> // for std::fabs
#include "caffe/util/mkl_alternate.hpp"
@ -112,6 +113,88 @@ void caffe_gpu_dot(const int n, const Dtype* x, const Dtype* y, Dtype* out);
template <typename Dtype>
int caffe_hamming_distance(const int n, const Dtype* x, const Dtype* y);
// Returns the sum of the absolute values of the elements of vector x
template <typename Dtype>
Dtype caffe_cpu_asum(const int n, const Dtype* x);
template <typename Dtype>
void caffe_gpu_asum(const int n, const Dtype* x, Dtype* y);
// the branchless, type-safe version from
// http://stackoverflow.com/questions/1903954/is-there-a-standard-sign-function-signum-sgn-in-c-c
template<typename Dtype>
inline char caffe_sign(Dtype val) {
return (Dtype(0) < val) - (val < Dtype(0));
}
// The following two macros are modifications of DEFINE_VSL_UNARY_FUNC
// in include/caffe/util/mkl_alternate.hpp authored by @Rowland Depp.
// Please refer to commit 7e8ef25c7 of the boost-eigen branch.
// Git cherry picking that commit caused a conflict hard to resolve and
// copying that file in convenient for code reviewing.
// So they have to be pasted here temporarily.
#define DEFINE_CAFFE_CPU_UNARY_FUNC(name, operation) \
template<typename Dtype> \
void caffe_cpu_##name(const int n, const Dtype* x, Dtype* y) { \
CHECK_GT(n, 0); CHECK(x); CHECK(y); \
for (int i = 0; i < n; ++i) { \
operation; \
} \
}
#define INSTANTIATE_CAFFE_CPU_UNARY_FUNC(name) \
template <> \
void caffe_cpu_##name<float>(const int n, const float* x, float* y); \
template <> \
void caffe_cpu_##name<double>(const int n, const double* x, double* y)
#define DEFINE_AND_INSTANTIATE_GPU_UNARY_FUNC(name, operation) \
template<typename Dtype> \
__global__ void name##_kernel(const int n, const Dtype* x, Dtype* y) { \
int index = threadIdx.x + blockIdx.x * blockDim.x; \
if (index < n) { \
operation; \
} \
} \
template <> \
void caffe_gpu_##name<float>(const int n, const float* x, float* y) { \
/* NOLINT_NEXT_LINE(whitespace/operators) */ \
name##_kernel<float><<<CAFFE_GET_BLOCKS(n), CAFFE_CUDA_NUM_THREADS>>>( \
n, x, y); \
} \
template <> \
void caffe_gpu_##name<double>(const int n, const double* x, double* y) { \
/* NOLINT_NEXT_LINE(whitespace/operators) */ \
name##_kernel<double><<<CAFFE_GET_BLOCKS(n), CAFFE_CUDA_NUM_THREADS>>>( \
n, x, y); \
}
// output is 1 for the positives, 0 for zero, and -1 for the negatives
DEFINE_CAFFE_CPU_UNARY_FUNC(sign, y[i] = caffe_sign<Dtype>(x[i]));
template<typename Dtype>
void caffe_gpu_sign(const int n, const Dtype* x, Dtype* y);
// This returns a nonzero value if the input has its sign bit set.
// The name sngbit is meant to avoid conflicts with std::signbit in the macro
using std::signbit;
DEFINE_CAFFE_CPU_UNARY_FUNC(sgnbit, y[i] = signbit(x[i]));
template<typename Dtype>
void caffe_gpu_sgnbit(const int n, const Dtype* x, Dtype* y);
DEFINE_CAFFE_CPU_UNARY_FUNC(fabs, y[i] = std::fabs(x[i]));
template <typename Dtype>
void caffe_gpu_fabs(const int n, const Dtype* x, Dtype* y);
template <typename Dtype>
void caffe_cpu_scale(const int n, const Dtype alpha, const Dtype *x, Dtype* y);
template <typename Dtype>
void caffe_gpu_scale(const int n, const Dtype alpha, const Dtype *x, Dtype* y);
} // namespace caffe

Просмотреть файл

@ -1,6 +1,10 @@
// Copyright 2014 kloudkl@github
#include <stdint.h> // for uint32_t & uint64_t
#include <time.h>
#include <climits>
#include <cmath> // for std::fabs
#include <cstdlib> // for rand_r
#include "gtest/gtest.h"
#include "caffe/blob.hpp"
@ -22,8 +26,8 @@ class MathFunctionsTest : public ::testing::Test {
virtual void SetUp() {
Caffe::set_random_seed(1701);
this->blob_bottom_->Reshape(100, 70, 50, 30);
this->blob_top_->Reshape(100, 70, 50, 30);
this->blob_bottom_->Reshape(11, 17, 19, 23);
this->blob_top_->Reshape(11, 17, 19, 23);
// fill the values
FillerParameter filler_param;
GaussianFiller<Dtype> filler(filler_param);
@ -74,4 +78,118 @@ TYPED_TEST(MathFunctionsTest, TestHammingDistance) {
caffe_hamming_distance<TypeParam>(n, x, y));
}
TYPED_TEST(MathFunctionsTest, TestAsumCPU) {
int n = this->blob_bottom_->count();
const TypeParam* x = this->blob_bottom_->cpu_data();
TypeParam std_asum = 0;
for (int i = 0; i < n; ++i) {
std_asum += std::fabs(x[i]);
}
TypeParam cpu_asum = caffe_cpu_asum<TypeParam>(n, x);
CHECK_LT((cpu_asum - std_asum) / std_asum, 1e-2);
}
TYPED_TEST(MathFunctionsTest, TestAsumGPU) {
int n = this->blob_bottom_->count();
const TypeParam* x = this->blob_bottom_->cpu_data();
TypeParam std_asum = 0;
for (int i = 0; i < n; ++i) {
std_asum += std::fabs(x[i]);
}
TypeParam gpu_asum;
caffe_gpu_asum<TypeParam>(n, this->blob_bottom_->gpu_data(), &gpu_asum);
CHECK_LT((gpu_asum - std_asum) / std_asum, 1e-2);
}
TYPED_TEST(MathFunctionsTest, TestSignCPU) {
int n = this->blob_bottom_->count();
const TypeParam* x = this->blob_bottom_->cpu_data();
caffe_cpu_sign<TypeParam>(n, x, this->blob_bottom_->mutable_cpu_diff());
const TypeParam* signs = this->blob_bottom_->cpu_diff();
for (int i = 0; i < n; ++i) {
CHECK_EQ(signs[i], x[i] > 0 ? 1 : (x[i] < 0 ? -1 : 0));
}
}
TYPED_TEST(MathFunctionsTest, TestSignGPU) {
int n = this->blob_bottom_->count();
caffe_gpu_sign<TypeParam>(n, this->blob_bottom_->gpu_data(),
this->blob_bottom_->mutable_gpu_diff());
const TypeParam* signs = this->blob_bottom_->cpu_diff();
const TypeParam* x = this->blob_bottom_->cpu_data();
for (int i = 0; i < n; ++i) {
CHECK_EQ(signs[i], x[i] > 0 ? 1 : (x[i] < 0 ? -1 : 0));
}
}
TYPED_TEST(MathFunctionsTest, TestSgnbitCPU) {
int n = this->blob_bottom_->count();
const TypeParam* x = this->blob_bottom_->cpu_data();
caffe_cpu_sgnbit<TypeParam>(n, x, this->blob_bottom_->mutable_cpu_diff());
const TypeParam* signbits = this->blob_bottom_->cpu_diff();
for (int i = 0; i < n; ++i) {
CHECK_EQ(signbits[i], x[i] < 0 ? 1 : 0);
}
}
TYPED_TEST(MathFunctionsTest, TestSgnbitGPU) {
int n = this->blob_bottom_->count();
caffe_gpu_sgnbit<TypeParam>(n, this->blob_bottom_->gpu_data(),
this->blob_bottom_->mutable_gpu_diff());
const TypeParam* signbits = this->blob_bottom_->cpu_diff();
const TypeParam* x = this->blob_bottom_->cpu_data();
for (int i = 0; i < n; ++i) {
CHECK_EQ(signbits[i], x[i] < 0 ? 1 : 0);
}
}
TYPED_TEST(MathFunctionsTest, TestFabsCPU) {
int n = this->blob_bottom_->count();
const TypeParam* x = this->blob_bottom_->cpu_data();
caffe_cpu_fabs<TypeParam>(n, x, this->blob_bottom_->mutable_cpu_diff());
const TypeParam* abs_val = this->blob_bottom_->cpu_diff();
for (int i = 0; i < n; ++i) {
CHECK_EQ(abs_val[i], x[i] > 0 ? x[i] : -x[i]);
}
}
TYPED_TEST(MathFunctionsTest, TestFabsGPU) {
int n = this->blob_bottom_->count();
caffe_gpu_fabs<TypeParam>(n, this->blob_bottom_->gpu_data(),
this->blob_bottom_->mutable_gpu_diff());
const TypeParam* abs_val = this->blob_bottom_->cpu_diff();
const TypeParam* x = this->blob_bottom_->cpu_data();
for (int i = 0; i < n; ++i) {
CHECK_EQ(abs_val[i], x[i] > 0 ? x[i] : -x[i]);
}
}
TYPED_TEST(MathFunctionsTest, TestScaleCPU) {
int n = this->blob_bottom_->count();
// NOLINT_NEXT_LINE(runtime/threadsafe_fn)
TypeParam alpha = this->blob_bottom_->cpu_diff()[rand() %
this->blob_bottom_->count()];
caffe_cpu_scale<TypeParam>(n, alpha, this->blob_bottom_->cpu_data(),
this->blob_bottom_->mutable_cpu_diff());
const TypeParam* scaled = this->blob_bottom_->cpu_diff();
const TypeParam* x = this->blob_bottom_->cpu_data();
for (int i = 0; i < n; ++i) {
CHECK_EQ(scaled[i], x[i] * alpha);
}
}
TYPED_TEST(MathFunctionsTest, TestScaleGPU) {
int n = this->blob_bottom_->count();
// NOLINT_NEXT_LINE(runtime/threadsafe_fn)
TypeParam alpha = this->blob_bottom_->cpu_diff()[rand() %
this->blob_bottom_->count()];
caffe_gpu_scale<TypeParam>(n, alpha, this->blob_bottom_->gpu_data(),
this->blob_bottom_->mutable_gpu_diff());
const TypeParam* scaled = this->blob_bottom_->cpu_diff();
const TypeParam* x = this->blob_bottom_->cpu_data();
for (int i = 0; i < n; ++i) {
CHECK_EQ(scaled[i], x[i] * alpha);
}
}
} // namespace caffe

Просмотреть файл

@ -390,4 +390,56 @@ int caffe_hamming_distance<double>(const int n, const double* x,
return dist;
}
template <>
float caffe_cpu_asum<float>(const int n, const float* x) {
return cblas_sasum(n, x, 1);
}
template <>
double caffe_cpu_asum<double>(const int n, const double* x) {
return cblas_dasum(n, x, 1);
}
template <>
void caffe_gpu_asum<float>(const int n, const float* x, float* y) {
CUBLAS_CHECK(cublasSasum(Caffe::cublas_handle(), n, x, 1, y));
}
template <>
void caffe_gpu_asum<double>(const int n, const double* x, double* y) {
CUBLAS_CHECK(cublasDasum(Caffe::cublas_handle(), n, x, 1, y));
}
INSTANTIATE_CAFFE_CPU_UNARY_FUNC(sign);
INSTANTIATE_CAFFE_CPU_UNARY_FUNC(sgnbit);
INSTANTIATE_CAFFE_CPU_UNARY_FUNC(fabs);
template <>
void caffe_cpu_scale<float>(const int n, const float alpha, const float *x,
float* y) {
cblas_scopy(n, x, 1, y, 1);
cblas_sscal(n, alpha, y, 1);
}
template <>
void caffe_cpu_scale<double>(const int n, const double alpha, const double *x,
double* y) {
cblas_dcopy(n, x, 1, y, 1);
cblas_dscal(n, alpha, y, 1);
}
template <>
void caffe_gpu_scale<float>(const int n, const float alpha, const float *x,
float* y) {
CUBLAS_CHECK(cublasScopy(Caffe::cublas_handle(), n, x, 1, y, 1));
CUBLAS_CHECK(cublasSscal(Caffe::cublas_handle(), n, &alpha, y, 1));
}
template <>
void caffe_gpu_scale<double>(const int n, const double alpha, const double *x,
double* y) {
CUBLAS_CHECK(cublasDcopy(Caffe::cublas_handle(), n, x, 1, y, 1));
CUBLAS_CHECK(cublasDscal(Caffe::cublas_handle(), n, &alpha, y, 1));
}
} // namespace caffe

Просмотреть файл

@ -1,5 +1,7 @@
// Copyright 2013 Yangqing Jia
// Copyright 2014 kloudkl@github
#include <math_functions.h> // CUDA's, not caffe's, for fabs, signbit
#include <cmath>
#include <cstdlib>
#include <cstring>
@ -33,5 +35,9 @@ void caffe_gpu_mul<double>(const int N, const double* a,
N, a, b, y);
}
DEFINE_AND_INSTANTIATE_GPU_UNARY_FUNC(sign, y[index] = (Dtype(0) < x[index])
- (x[index] < Dtype(0)));
DEFINE_AND_INSTANTIATE_GPU_UNARY_FUNC(sgnbit, y[index] = signbit(x[index]));
DEFINE_AND_INSTANTIATE_GPU_UNARY_FUNC(fabs, y[index] = fabs(x[index]));
} // namespace caffe