Merge pull request #165 from BVLC/boost-eigen

MKL/non-MKL Reconciliation

Caffe no longer requires MKL. By default it builds without it, relying on ATLAS and CBLAS instead. Set the `USE_MKL` variable in your Makefile.config accordingly.
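For example, to keep building against MKL, flip the switch this change adds to Makefile.config (its default, shown in the diff below, is 0):

    # MKL switch: set to 1 for MKL
    USE_MKL := 1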
Evan Shelhamer 2014-03-22 22:53:42 -07:00
Parent 510b3c028f bece205114
Commit 699b557c75
21 changed files with 526 additions and 176 deletions

View File

@@ -86,27 +86,37 @@ CUDA_LIB_DIR := $(CUDA_DIR)/lib64 $(CUDA_DIR)/lib
MKL_INCLUDE_DIR := $(MKL_DIR)/include
MKL_LIB_DIR := $(MKL_DIR)/lib $(MKL_DIR)/lib/intel64
INCLUDE_DIRS += ./src ./include $(CUDA_INCLUDE_DIR) $(MKL_INCLUDE_DIR)
LIBRARY_DIRS += $(CUDA_LIB_DIR) $(MKL_LIB_DIR)
INCLUDE_DIRS += ./src ./include $(CUDA_INCLUDE_DIR)
LIBRARY_DIRS += $(CUDA_LIB_DIR)
LIBRARIES := cudart cublas curand \
mkl_rt \
pthread \
glog protobuf leveldb \
snappy \
glog protobuf leveldb snappy \
boost_system \
hdf5_hl hdf5 \
opencv_core opencv_highgui opencv_imgproc
PYTHON_LIBRARIES := boost_python python2.7
WARNINGS := -Wall
COMMON_FLAGS := -DNDEBUG -O2 $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir))
COMMON_FLAGS := -DNDEBUG -O2
# MKL switch (default = non-MKL)
USE_MKL ?= 0
ifeq ($(USE_MKL), 1)
LIBRARIES += mkl_rt
COMMON_FLAGS += -DUSE_MKL
INCLUDE_DIRS += $(MKL_INCLUDE_DIR)
LIBRARY_DIRS += $(MKL_LIB_DIR)
else
LIBRARIES += cblas atlas
endif
COMMON_FLAGS += $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir))
CXXFLAGS += -pthread -fPIC $(COMMON_FLAGS)
NVCCFLAGS := -ccbin=$(CXX) -Xcompiler -fPIC $(COMMON_FLAGS)
LDFLAGS += $(foreach librarydir,$(LIBRARY_DIRS),-L$(librarydir)) \
$(foreach library,$(LIBRARIES),-l$(library))
PYTHON_LDFLAGS := $(LDFLAGS) $(foreach library,$(PYTHON_LIBRARIES),-l$(library))
##############################
# Define build targets
##############################

View File

@@ -10,6 +10,8 @@ CUDA_ARCH := -gencode arch=compute_20,code=sm_20 \
-gencode arch=compute_30,code=sm_30 \
-gencode arch=compute_35,code=sm_35
# MKL switch: set to 1 for MKL
USE_MKL := 0
# MKL directory contains the include/ and lib/ directories that we need.
MKL_DIR := /opt/intel/mkl

View File

@@ -27,6 +27,14 @@ class Blob {
inline int count() const { return count_; }
inline int offset(const int n, const int c = 0, const int h = 0,
const int w = 0) const {
CHECK_GE(n, 0);
CHECK_LE(n, num_);
CHECK_GE(c, 0);
CHECK_LE(c, channels_);
CHECK_GE(h, 0);
CHECK_LE(h, height_);
CHECK_GE(w, 0);
CHECK_LE(w, width_);
return ((n * channels_ + c) * height_ + h) * width_ + w;
}
// Copy from source. If copy_diff is false, we copy the data; if copy_diff
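As an aside, the offset arithmetic linearizes an (n, c, h, w) index into row-major NCHW storage; a standalone sketch with made-up dimensions (not part of the diff):

    // With channels = 3, height = 4, width = 5, the index (n=1, c=2, h=3, w=0)
    // maps to ((1 * 3 + 2) * 4 + 3) * 5 + 0 = 115.
    const int channels = 3, height = 4, width = 5;
    const int n = 1, c = 2, h = 3, w = 0;
    const int idx = ((n * channels + c) * height + h) * width + w;  // 115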

View File

@@ -1,4 +1,4 @@
// Copyright 2013 Yangqing Jia
// Copyright 2014 BVLC and contributors.
#ifndef CAFFE_COMMON_HPP_
#define CAFFE_COMMON_HPP_
@@ -7,28 +7,8 @@
#include <cublas_v2.h>
#include <cuda.h>
#include <curand.h>
// cuda driver types
#include <driver_types.h>
#include <driver_types.h> // cuda driver types
#include <glog/logging.h>
#include <mkl_vsl.h>
// various checks for different function calls.
#define CUDA_CHECK(condition) CHECK_EQ((condition), cudaSuccess)
#define CUBLAS_CHECK(condition) CHECK_EQ((condition), CUBLAS_STATUS_SUCCESS)
#define CURAND_CHECK(condition) CHECK_EQ((condition), CURAND_STATUS_SUCCESS)
#define VSL_CHECK(condition) CHECK_EQ((condition), VSL_STATUS_OK)
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
i < (n); \
i += blockDim.x * gridDim.x)
// After a kernel is executed, this will check the error and if there is one,
// exit loudly.
#define CUDA_POST_KERNEL_CHECK \
if (cudaSuccess != cudaPeekAtLastError()) \
LOG(FATAL) << "Cuda kernel failed. Error: " \
<< cudaGetErrorString(cudaPeekAtLastError())
// Disable the copy and assignment operator for a class.
#define DISABLE_COPY_AND_ASSIGN(classname) \
@@ -45,6 +25,23 @@ private:\
// is executed we will see a fatal log.
#define NOT_IMPLEMENTED LOG(FATAL) << "Not Implemented Yet"
// CUDA: various checks for different function calls.
#define CUDA_CHECK(condition) CHECK_EQ((condition), cudaSuccess)
#define CUBLAS_CHECK(condition) CHECK_EQ((condition), CUBLAS_STATUS_SUCCESS)
#define CURAND_CHECK(condition) CHECK_EQ((condition), CURAND_STATUS_SUCCESS)
// CUDA: grid stride looping
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
i < (n); \
i += blockDim.x * gridDim.x)
// CUDA: check for error after kernel execution and exit loudly if there is one.
#define CUDA_POST_KERNEL_CHECK \
if (cudaSuccess != cudaPeekAtLastError()) \
LOG(FATAL) << "Cuda kernel failed. Error: " \
<< cudaGetErrorString(cudaPeekAtLastError())
namespace caffe {
@@ -53,20 +50,6 @@ namespace caffe {
using boost::shared_ptr;
// We will use 1024 threads per block, which requires cuda sm_2x or above.
#if __CUDA_ARCH__ >= 200
const int CAFFE_CUDA_NUM_THREADS = 1024;
#else
const int CAFFE_CUDA_NUM_THREADS = 512;
#endif
inline int CAFFE_GET_BLOCKS(const int N) {
return (N + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS;
}
// A singleton class to hold common caffe stuff, such as the handler that
// caffe is going to use for cublas, curand, etc.
class Caffe {
@@ -81,15 +64,32 @@ class Caffe {
enum Brew { CPU, GPU };
enum Phase { TRAIN, TEST };
// The getters for the variables.
// Returns the cublas handle.
// This random number generator facade hides boost and CUDA rng
// implementation from one another (for cross-platform compatibility).
class RNG {
public:
RNG();
explicit RNG(unsigned int seed);
~RNG();
RNG(const RNG&);
RNG& operator=(const RNG&);
const void* generator() const;
void* generator();
private:
class Generator;
Generator* generator_;
};
// Getters for boost rng, curand, and cublas handles
inline static RNG &rng_stream() {
return Get().random_generator_;
}
inline static cublasHandle_t cublas_handle() { return Get().cublas_handle_; }
// Returns the curand generator.
inline static curandGenerator_t curand_generator() {
return Get().curand_generator_;
}
// Returns the MKL random stream.
inline static VSLStreamStatePtr vsl_stream() { return Get().vsl_stream_; }
// Returns the mode: running on CPU or GPU.
inline static Brew mode() { return Get().mode_; }
// Returns the phase: TRAIN or TEST.
@@ -102,7 +102,7 @@ class Caffe {
inline static void set_mode(Brew mode) { Get().mode_ = mode; }
// Sets the phase.
inline static void set_phase(Phase phase) { Get().phase_ = phase; }
// Sets the random seed of both MKL and curand
// Sets the random seed of both boost and curand
static void set_random_seed(const unsigned int seed);
// Sets the device. Since we have cublas and curand stuff, set device also
// requires us to reset those values.
@@ -113,7 +113,8 @@ class Caffe {
protected:
cublasHandle_t cublas_handle_;
curandGenerator_t curand_generator_;
VSLStreamStatePtr vsl_stream_;
RNG random_generator_;
Brew mode_;
Phase phase_;
static shared_ptr<Caffe> singleton_;
@@ -126,6 +127,21 @@ class Caffe {
};
// CUDA: thread number configuration.
// Use 1024 threads per block, which requires cuda sm_2x or above,
// or fall back to attempt compatibility (best of luck to you).
#if __CUDA_ARCH__ >= 200
const int CAFFE_CUDA_NUM_THREADS = 1024;
#else
const int CAFFE_CUDA_NUM_THREADS = 512;
#endif
// CUDA: number of blocks for threads.
inline int CAFFE_GET_BLOCKS(const int N) {
return (N + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS;
}
} // namespace caffe
#endif // CAFFE_COMMON_HPP_
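Taken together, the relocated macros give the usual launch pattern; a hypothetical element-wise kernel as a sketch (scale_kernel is illustrative, not part of this commit):

    template <typename Dtype>
    __global__ void scale_kernel(const int n, const Dtype alpha, Dtype* y) {
      CUDA_KERNEL_LOOP(index, n) {
        y[index] *= alpha;  // the grid-stride loop covers all n elements
      }
    }
    // scale_kernel<Dtype><<<CAFFE_GET_BLOCKS(n), CAFFE_CUDA_NUM_THREADS>>>(n, alpha, y);
    // CUDA_POST_KERNEL_CHECK;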

View File

@@ -7,7 +7,6 @@
#ifndef CAFFE_FILLER_HPP
#define CAFFE_FILLER_HPP
#include <mkl.h>
#include <string>
#include "caffe/common.hpp"

View File

@@ -4,9 +4,11 @@
#ifndef CAFFE_UTIL_MATH_FUNCTIONS_H_
#define CAFFE_UTIL_MATH_FUNCTIONS_H_
#include <mkl.h>
#include <cublas_v2.h>
#include "caffe/util/mkl_alternate.hpp"
namespace caffe {
// Decaf gemm provides a simpler interface to the gemm functions, with the
@@ -45,7 +47,7 @@ void caffe_gpu_axpy(const int N, const Dtype alpha, const Dtype* X,
Dtype* Y);
template <typename Dtype>
void caffe_axpby(const int N, const Dtype alpha, const Dtype* X,
void caffe_cpu_axpby(const int N, const Dtype alpha, const Dtype* X,
const Dtype beta, Dtype* Y);
template <typename Dtype>
@@ -85,6 +87,9 @@ void caffe_div(const int N, const Dtype* a, const Dtype* b, Dtype* y);
template <typename Dtype>
void caffe_powx(const int n, const Dtype* a, const Dtype b, Dtype* y);
template <typename Dtype>
Dtype caffe_nextafter(const Dtype b);
template <typename Dtype>
void caffe_vRngUniform(const int n, Dtype* r, const Dtype a, const Dtype b);
@@ -92,6 +97,9 @@ template <typename Dtype>
void caffe_vRngGaussian(const int n, Dtype* r, const Dtype a,
const Dtype sigma);
template <typename Dtype>
void caffe_vRngBernoulli(const int n, Dtype* r, const double p);
template <typename Dtype>
void caffe_exp(const int n, const Dtype* a, Dtype* y);

View File

@@ -0,0 +1,97 @@
// Copyright 2013 Rowland Depp
#ifndef CAFFE_UTIL_MKL_ALTERNATE_H_
#define CAFFE_UTIL_MKL_ALTERNATE_H_
#ifdef USE_MKL  // If using MKL, simply include the MKL header.
#include <mkl.h>
#else  // Otherwise, fall back to plain CBLAS and define the missing pieces below.
extern "C" {
#include <cblas.h>
}
#include <math.h>
// Functions that caffe uses but are not present if MKL is not linked.
// A simple way to define the vsl unary functions. The operation should
// be in the form e.g. y[i] = sqrt(a[i])
#define DEFINE_VSL_UNARY_FUNC(name, operation) \
template<typename Dtype> \
void v##name(const int n, const Dtype* a, Dtype* y) { \
CHECK_GT(n, 0); CHECK(a); CHECK(y); \
for (int i = 0; i < n; ++i) { operation; } \
} \
inline void vs##name( \
const int n, const float* a, float* y) { \
v##name<float>(n, a, y); \
} \
inline void vd##name( \
const int n, const double* a, double* y) { \
v##name<double>(n, a, y); \
}
DEFINE_VSL_UNARY_FUNC(Sqr, y[i] = a[i] * a[i]);
DEFINE_VSL_UNARY_FUNC(Exp, y[i] = exp(a[i]));
// A simple way to define the vsl unary functions with a scalar parameter b.
// The operation should be in the form e.g. y[i] = pow(a[i], b)
#define DEFINE_VSL_UNARY_FUNC_WITH_PARAM(name, operation) \
template<typename Dtype> \
void v##name(const int n, const Dtype* a, const Dtype b, Dtype* y) { \
CHECK_GT(n, 0); CHECK(a); CHECK(y); \
for (int i = 0; i < n; ++i) { operation; } \
} \
inline void vs##name( \
const int n, const float* a, const float b, float* y) { \
v##name<float>(n, a, b, y); \
} \
inline void vd##name( \
const int n, const double* a, const double b, double* y) { \
v##name<double>(n, a, b, y); \
}
DEFINE_VSL_UNARY_FUNC_WITH_PARAM(Powx, y[i] = pow(a[i], b));
// A simple way to define the vsl binary functions. The operation should
// be in the form e.g. y[i] = a[i] + b[i]
#define DEFINE_VSL_BINARY_FUNC(name, operation) \
template<typename Dtype> \
void v##name(const int n, const Dtype* a, const Dtype* b, Dtype* y) { \
CHECK_GT(n, 0); CHECK(a); CHECK(b); CHECK(y); \
for (int i = 0; i < n; ++i) { operation; } \
} \
inline void vs##name( \
const int n, const float* a, const float* b, float* y) { \
v##name<float>(n, a, b, y); \
} \
inline void vd##name( \
const int n, const double* a, const double* b, double* y) { \
v##name<double>(n, a, b, y); \
}
DEFINE_VSL_BINARY_FUNC(Add, y[i] = a[i] + b[i]);
DEFINE_VSL_BINARY_FUNC(Sub, y[i] = a[i] - b[i]);
DEFINE_VSL_BINARY_FUNC(Mul, y[i] = a[i] * b[i]);
DEFINE_VSL_BINARY_FUNC(Div, y[i] = a[i] / b[i]);
// In addition, MKL provides an axpby function that is not present in
// standard BLAS. We mimic it (inefficiently, in two passes) with a scal
// followed by an axpy.
inline void cblas_saxpby(const int N, const float alpha, const float* X,
const int incX, const float beta, float* Y,
const int incY) {
cblas_sscal(N, beta, Y, incY);
cblas_saxpy(N, alpha, X, incX, Y, incY);
}
inline void cblas_daxpby(const int N, const double alpha, const double* X,
const int incX, const double beta, double* Y,
const int incY) {
cblas_dscal(N, beta, Y, incY);
cblas_daxpy(N, alpha, X, incX, Y, incY);
}
#endif // USE_MKL
#endif // CAFFE_UTIL_MKL_ALTERNATE_H_
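Whichever branch is compiled, callers keep the MKL spelling. A usage sketch with illustrative values:

    // Element-wise square through the shim: MKL's vsSqr when USE_MKL is
    // defined, the macro-generated loop otherwise.
    float a[3] = {1.0f, 2.0f, 3.0f};
    float y[3];
    vsSqr(3, a, y);                           // y = {1, 4, 9}

    // Y = alpha * X + beta * Y via the two-step scal-then-axpy emulation.
    float x[3] = {1.0f, 1.0f, 1.0f};
    cblas_saxpby(3, 2.0f, x, 1, 0.5f, y, 1);  // y = {2.5, 4.0, 6.5}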

View File

@@ -0,0 +1,19 @@
// Copyright 2014 BVLC and contributors.
#ifndef CAFFE_RNG_CPP_HPP_
#define CAFFE_RNG_CPP_HPP_
#include <boost/random/mersenne_twister.hpp>
#include "caffe/common.hpp"
namespace caffe {
typedef boost::mt19937 rng_t;
inline rng_t& caffe_rng() {
Caffe::RNG &generator = Caffe::rng_stream();
return *static_cast<caffe::rng_t*>(generator.generator());
}
} // namespace caffe
#endif  // CAFFE_RNG_CPP_HPP_
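Callers pair caffe_rng() with a boost distribution, which is the pattern math_functions.cpp adopts later in this commit; a minimal sketch, assuming <boost/random.hpp> is included:

    boost::uniform_real<double> dist(0.0, 1.0);
    boost::variate_generator<caffe::rng_t, boost::uniform_real<double> >
        sampler(caffe_rng(), dist);
    double x = sampler();  // one uniform draw from the shared generator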

View File

@@ -1,15 +1,17 @@
// Copyright 2013 Yangqing Jia
// Copyright 2014 BVLC and contributors.
#include <cstdio>
#include <ctime>
#include "caffe/common.hpp"
#include "caffe/util/rng.hpp"
namespace caffe {
shared_ptr<Caffe> Caffe::singleton_;
// curand seeding
int64_t cluster_seedgen(void) {
int64_t s, seed, pid;
pid = getpid();
@@ -21,7 +23,8 @@ int64_t cluster_seedgen(void) {
Caffe::Caffe()
: mode_(Caffe::CPU), phase_(Caffe::TRAIN), cublas_handle_(NULL),
curand_generator_(NULL), vsl_stream_(NULL) {
curand_generator_(NULL),
random_generator_() {
// Try to create a cublas handler, and report an error if failed (but we will
// keep the program running as one might just want to run CPU code).
if (cublasCreate(&cublas_handle_) != CUBLAS_STATUS_SUCCESS) {
@@ -34,13 +37,6 @@ Caffe::Caffe()
!= CURAND_STATUS_SUCCESS) {
LOG(ERROR) << "Cannot create Curand generator. Curand won't be available.";
}
// Try to create a vsl stream. This should almost always work, but we will
// check it anyway.
if (vslNewStream(&vsl_stream_, VSL_BRNG_MT19937,
cluster_seedgen()) != VSL_STATUS_OK) {
LOG(ERROR) << "Cannot create vsl stream. VSL random number generator "
<< "won't be available.";
}
}
Caffe::~Caffe() {
@@ -48,7 +44,6 @@ Caffe::~Caffe() {
if (curand_generator_) {
CURAND_CHECK(curandDestroyGenerator(curand_generator_));
}
if (vsl_stream_) VSL_CHECK(vslDeleteStream(&vsl_stream_));
}
void Caffe::set_random_seed(const unsigned int seed) {
@@ -64,9 +59,8 @@ void Caffe::set_random_seed(const unsigned int seed) {
} else {
LOG(ERROR) << "Curand not available. Skipping setting the curand seed.";
}
// VSL seed
VSL_CHECK(vslDeleteStream(&(Get().vsl_stream_)));
VSL_CHECK(vslNewStream(&(Get().vsl_stream_), VSL_BRNG_MT19937, seed));
// RNG seed
Get().random_generator_ = RNG(seed);
}
void Caffe::SetDevice(const int device_id) {
@@ -120,4 +114,37 @@ void Caffe::DeviceQuery() {
return;
}
class Caffe::RNG::Generator {
public:
caffe::rng_t rng;
};
Caffe::RNG::RNG()
: generator_(new Generator) { }
Caffe::RNG::RNG(unsigned int seed)
: generator_(new Generator) {
generator_->rng = caffe::rng_t(seed);
}
Caffe::RNG::~RNG() { delete generator_; }
Caffe::RNG::RNG(const RNG& other) : generator_(new Generator) {
*generator_ = *other.generator_;
}
Caffe::RNG& Caffe::RNG::operator=(const RNG& other) {
*generator_ = *other.generator_;
return *this;
}
void* Caffe::RNG::generator() {
return &generator_->rng;
}
const void* Caffe::RNG::generator() const {
return &generator_->rng;
}
} // namespace caffe
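The net effect is that one seed now controls the boost stream as well as curand; a sketch mirroring the updated TestRandSeedCPU test below:

    int mask_a[10], mask_b[10];
    Caffe::set_random_seed(1701);
    caffe_vRngBernoulli(10, mask_a, 0.5);
    Caffe::set_random_seed(1701);
    caffe_vRngBernoulli(10, mask_b, 0.5);
    // mask_a and mask_b now hold identical sequences.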

View File

@@ -3,6 +3,7 @@
#include <vector>
#include "caffe/common.hpp"
#include "caffe/util/math_functions.hpp"
#include "caffe/layer.hpp"
#include "caffe/syncedmem.hpp"
#include "caffe/vision_layers.hpp"
@@ -31,8 +32,7 @@ Dtype DropoutLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const int count = bottom[0]->count();
if (Caffe::phase() == Caffe::TRAIN) {
// Create random numbers
viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, Caffe::vsl_stream(),
count, mask, 1. - threshold_);
caffe_vRngBernoulli<int>(count, mask, 1. - threshold_);
for (int i = 0; i < count; ++i) {
top_data[i] = bottom_data[i] * mask[i] * scale_;
}
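A note on the scaling: assuming scale_ = 1 / (1 - threshold_), as set in the layer's SetUp, the mask-and-scale pair preserves expectations, since E[mask[i] * scale_] = (1 - threshold_) / (1 - threshold_) = 1, so train-time activations match the magnitude the unscaled test-time pass produces.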

View File

@@ -1,8 +1,5 @@
// Copyright 2013 Yangqing Jia
#include <mkl.h>
#include <vector>
#include "caffe/blob.hpp"

View File

@@ -1,7 +1,5 @@
// Copyright 2013 Yangqing Jia
#include <mkl.h>
#include <cublas_v2.h>
#include <vector>

View File

@@ -154,7 +154,7 @@ void EuclideanLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
int count = (*bottom)[0]->count();
int num = (*bottom)[0]->num();
// Compute the gradient
caffe_axpby(count, Dtype(1) / num, difference_.cpu_data(), Dtype(0),
caffe_cpu_axpby(count, Dtype(1) / num, difference_.cpu_data(), Dtype(0),
(*bottom)[0]->mutable_cpu_diff());
}

View File

@@ -215,7 +215,7 @@ void SGDSolver<Dtype>::ComputeUpdateValue() {
// Compute the value to history, and then copy them to the blob's diff.
Dtype local_rate = rate * net_params_lr[param_id];
Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
caffe_axpby(net_params[param_id]->count(), local_rate,
caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
net_params[param_id]->cpu_diff(), momentum,
history_[param_id]->mutable_cpu_data());
if (local_decay) {
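Written out, the renamed call computes history = local_rate * diff + momentum * history (axpby with alpha = local_rate and beta = momentum), i.e. the standard momentum update v <- mu * v + lr * grad, which the solver then copies into the parameter's diff.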

View File

@@ -6,7 +6,7 @@
#include "gtest/gtest.h"
#include "caffe/common.hpp"
#include "caffe/syncedmem.hpp"
#include "caffe/util/math_functions.hpp"
#include "caffe/test/test_caffe_main.hpp"
namespace caffe {
@@ -19,10 +19,6 @@ TEST_F(CommonTest, TestCublasHandler) {
EXPECT_TRUE(Caffe::cublas_handle());
}
TEST_F(CommonTest, TestVslStream) {
EXPECT_TRUE(Caffe::vsl_stream());
}
TEST_F(CommonTest, TestBrewMode) {
Caffe::set_mode(Caffe::CPU);
EXPECT_EQ(Caffe::mode(), Caffe::CPU);
@@ -40,18 +36,19 @@ TEST_F(CommonTest, TestRandSeedCPU) {
SyncedMemory data_a(10 * sizeof(int));
SyncedMemory data_b(10 * sizeof(int));
Caffe::set_random_seed(1701);
viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, Caffe::vsl_stream(),
10, reinterpret_cast<int*>(data_a.mutable_cpu_data()), 0.5);
caffe_vRngBernoulli(10,
reinterpret_cast<int*>(data_a.mutable_cpu_data()), 0.5);
Caffe::set_random_seed(1701);
viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, Caffe::vsl_stream(),
10, reinterpret_cast<int*>(data_b.mutable_cpu_data()), 0.5);
caffe_vRngBernoulli(10,
reinterpret_cast<int*>(data_b.mutable_cpu_data()), 0.5);
for (int i = 0; i < 10; ++i) {
EXPECT_EQ(((const int*)(data_a.cpu_data()))[i],
((const int*)(data_b.cpu_data()))[i]);
}
}
TEST_F(CommonTest, TestRandSeedGPU) {
SyncedMemory data_a(10 * sizeof(unsigned int));
SyncedMemory data_b(10 * sizeof(unsigned int));
@@ -67,5 +64,4 @@ TEST_F(CommonTest, TestRandSeedGPU) {
}
}
} // namespace caffe

View File

@@ -23,6 +23,7 @@ class FlattenLayerTest : public ::testing::Test {
FlattenLayerTest()
: blob_bottom_(new Blob<Dtype>(2, 3, 6, 5)),
blob_top_(new Blob<Dtype>()) {
Caffe::set_random_seed(1701);
// fill the values
FillerParameter filler_param;
GaussianFiller<Dtype> filler(filler_param);
@@ -73,6 +74,8 @@ TYPED_TEST(FlattenLayerTest, TestGPU) {
for (int c = 0; c < 3 * 6 * 5; ++c) {
EXPECT_EQ(this->blob_top_->data_at(0, c, 0, 0),
this->blob_bottom_->data_at(0, c / (6 * 5), (c / 5) % 6, c % 5));
EXPECT_EQ(this->blob_top_->data_at(1, c, 0, 0),
this->blob_bottom_->data_at(1, c / (6 * 5), (c / 5) % 6, c % 5));
}
}

View File

@@ -25,6 +25,7 @@ class MultinomialLogisticLossLayerTest : public ::testing::Test {
MultinomialLogisticLossLayerTest()
: blob_bottom_data_(new Blob<Dtype>(10, 5, 1, 1)),
blob_bottom_label_(new Blob<Dtype>(10, 1, 1, 1)) {
Caffe::set_random_seed(1701);
// fill the values
FillerParameter filler_param;
PositiveUnitballFiller<Dtype> filler(filler_param);
@@ -55,7 +56,7 @@ TYPED_TEST(MultinomialLogisticLossLayerTest, TestGradientCPU) {
Caffe::set_mode(Caffe::CPU);
MultinomialLogisticLossLayer<TypeParam> layer(layer_param);
layer.SetUp(this->blob_bottom_vec_, &this->blob_top_vec_);
GradientChecker<TypeParam> checker(1e-2, 1e-2, 1701, 0, 0.05);
GradientChecker<TypeParam> checker(1e-2, 2*1e-2, 1701, 0, 0.05);
checker.CheckGradientSingle(&layer, &(this->blob_bottom_vec_),
&(this->blob_top_vec_), 0, -1, -1);
}

View File

@@ -0,0 +1,98 @@
// Copyright 2014 BVLC and contributors.
#include <cuda_runtime.h>
#include <cmath>
#include <cstring>
#include "gtest/gtest.h"
#include "caffe/common.hpp"
#include "caffe/syncedmem.hpp"
#include "caffe/util/math_functions.hpp"
#include "caffe/test/test_caffe_main.hpp"
namespace caffe {
template <typename Dtype>
class RandomNumberGeneratorTest : public ::testing::Test {
public:
virtual ~RandomNumberGeneratorTest() {}
Dtype sample_mean(const Dtype* const seqs, const size_t sample_size) {
double sum = 0;
for (int i = 0; i < sample_size; ++i) {
sum += seqs[i];
}
return sum / sample_size;
}
Dtype sample_mean(const int* const seqs, const size_t sample_size) {
Dtype sum = 0;
for (int i = 0; i < sample_size; ++i) {
sum += Dtype(seqs[i]);
}
return sum / sample_size;
}
Dtype mean_bound(const Dtype std, const size_t sample_size) {
return std/sqrt(static_cast<double>(sample_size));
}
};
typedef ::testing::Types<float, double> Dtypes;
TYPED_TEST_CASE(RandomNumberGeneratorTest, Dtypes);
TYPED_TEST(RandomNumberGeneratorTest, TestRngGaussian) {
size_t sample_size = 10000;
SyncedMemory data_a(sample_size * sizeof(TypeParam));
Caffe::set_random_seed(1701);
TypeParam mu = 0;
TypeParam sigma = 1;
caffe_vRngGaussian(sample_size,
reinterpret_cast<TypeParam*>(data_a.mutable_cpu_data()), mu, sigma);
TypeParam true_mean = mu;
TypeParam true_std = sigma;
TypeParam bound = this->mean_bound(true_std, sample_size);
TypeParam empirical_mean =
this->sample_mean(reinterpret_cast<const TypeParam*>(data_a.cpu_data()),
sample_size);
EXPECT_NEAR(empirical_mean, true_mean, bound);
}
TYPED_TEST(RandomNumberGeneratorTest, TestRngUniform) {
size_t sample_size = 10000;
SyncedMemory data_a(sample_size * sizeof(TypeParam));
Caffe::set_random_seed(1701);
TypeParam lower = 0;
TypeParam upper = 1;
caffe_vRngUniform(sample_size,
reinterpret_cast<TypeParam*>(data_a.mutable_cpu_data()), lower, upper);
TypeParam true_mean = (lower + upper) / 2;
TypeParam true_std = (upper - lower) / sqrt(12);
TypeParam bound = this->mean_bound(true_std, sample_size);
TypeParam empirical_mean =
this->sample_mean(reinterpret_cast<const TypeParam*>(data_a.cpu_data()),
sample_size);
EXPECT_NEAR(empirical_mean, true_mean, bound);
}
TYPED_TEST(RandomNumberGeneratorTest, TestRngBernoulli) {
size_t sample_size = 10000;
SyncedMemory data_a(sample_size * sizeof(int));
Caffe::set_random_seed(1701);
double p = 0.3;
caffe_vRngBernoulli(sample_size,
static_cast<int*>(data_a.mutable_cpu_data()), p);
TypeParam true_mean = p;
TypeParam true_std = sqrt(p * (1 - p));
TypeParam bound = this->mean_bound(true_std, sample_size);
TypeParam empirical_mean =
this->sample_mean((const int *)data_a.cpu_data(), sample_size);
EXPECT_NEAR(empirical_mean, true_mean, bound);
}
} // namespace caffe
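For reference, mean_bound is one standard error of the sample mean, std / sqrt(sample_size). With 10000 Bernoulli draws at p = 0.3 that is sqrt(0.3 * 0.7) / 100 ≈ 0.0046, so the empirical mean must land within about half a percentage point of 0.3; the fixed seed 1701 keeps this one-sigma check deterministic.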

View File

@@ -146,8 +146,6 @@ TYPED_TEST(StochasticPoolingLayerTest, TestStochasticGPUTestPhase) {
}
}
TYPED_TEST(StochasticPoolingLayerTest, TestGradientGPU) {
Caffe::set_mode(Caffe::GPU);
Caffe::set_phase(Caffe::TRAIN);
@@ -157,7 +155,7 @@ TYPED_TEST(StochasticPoolingLayerTest, TestGradientGPU) {
layer_param.set_pool(LayerParameter_PoolMethod_STOCHASTIC);
PoolingLayer<TypeParam> layer(layer_param);
GradientChecker<TypeParam> checker(1e-2, 1e-3);
GradientChecker<TypeParam> checker(1e-4, 1e-2);
// it is too expensive to call curand multiple times, so we don't do an
// exhaustive gradient check.
checker.CheckGradient(&layer, &(this->blob_bottom_vec_),

View File

@@ -3,7 +3,6 @@
#include <cstring>
#include "cuda_runtime.h"
#include "mkl.h"
#include "cublas_v2.h"
#include "gtest/gtest.h"

View File

@@ -1,10 +1,14 @@
// Copyright 2013 Yangqing Jia
// Copyright 2014 kloudkl@github
// Copyright 2014 BVLC and contributors.
#include <mkl.h>
#include <boost/math/special_functions/next.hpp>
#include <boost/random.hpp>
#include <cublas_v2.h>
#include <limits>
#include "caffe/common.hpp"
#include "caffe/util/math_functions.hpp"
#include "caffe/util/rng.hpp"
namespace caffe {
@@ -104,7 +108,6 @@ template <>
void caffe_axpy<double>(const int N, const double alpha, const double* X,
double* Y) { cblas_daxpy(N, alpha, X, 1, Y, 1); }
template <>
void caffe_gpu_axpy<float>(const int N, const float alpha, const float* X,
float* Y) {
@@ -117,18 +120,6 @@ void caffe_gpu_axpy<double>(const int N, const double alpha, const double* X,
CUBLAS_CHECK(cublasDaxpy(Caffe::cublas_handle(), N, &alpha, X, 1, Y, 1));
}
template <>
void caffe_axpby<float>(const int N, const float alpha, const float* X,
const float beta, float* Y) {
cblas_saxpby(N, alpha, X, 1, beta, Y, 1);
}
template <>
void caffe_axpby<double>(const int N, const double alpha, const double* X,
const double beta, double* Y) {
cblas_daxpby(N, alpha, X, 1, beta, Y, 1);
}
template <>
void caffe_copy<float>(const int N, const float* X, float* Y) {
cblas_scopy(N, X, 1, Y, 1);
@@ -183,6 +174,78 @@ void caffe_gpu_axpby<double>(const int N, const double alpha, const double* X,
caffe_gpu_axpy<double>(N, alpha, X, Y);
}
template <>
void caffe_cpu_axpby<float>(const int N, const float alpha, const float* X,
const float beta, float* Y) {
cblas_saxpby(N, alpha, X, 1, beta, Y, 1);
}
template <>
void caffe_cpu_axpby<double>(const int N, const double alpha, const double* X,
const double beta, double* Y) {
cblas_daxpby(N, alpha, X, 1, beta, Y, 1);
}
template <>
void caffe_add<float>(const int n, const float* a, const float* b,
float* y) {
vsAdd(n, a, b, y);
}
template <>
void caffe_add<double>(const int n, const double* a, const double* b,
double* y) {
vdAdd(n, a, b, y);
}
template <>
void caffe_sub<float>(const int n, const float* a, const float* b,
float* y) {
vsSub(n, a, b, y);
}
template <>
void caffe_sub<double>(const int n, const double* a, const double* b,
double* y) {
vdSub(n, a, b, y);
}
template <>
void caffe_mul<float>(const int n, const float* a, const float* b,
float* y) {
vsMul(n, a, b, y);
}
template <>
void caffe_mul<double>(const int n, const double* a, const double* b,
double* y) {
vdMul(n, a, b, y);
}
template <>
void caffe_div<float>(const int n, const float* a, const float* b,
float* y) {
vsDiv(n, a, b, y);
}
template <>
void caffe_div<double>(const int n, const double* a, const double* b,
double* y) {
vdDiv(n, a, b, y);
}
template <>
void caffe_powx<float>(const int n, const float* a, const float b,
float* y) {
vsPowx(n, a, b, y);
}
template <>
void caffe_powx<double>(const int n, const double* a, const double b,
double* y) {
vdPowx(n, a, b, y);
}
template <>
void caffe_sqr<float>(const int n, const float* a, float* y) {
vsSqr(n, a, y);
@@ -193,75 +256,6 @@ void caffe_sqr<double>(const int n, const double* a, double* y) {
vdSqr(n, a, y);
}
template <>
void caffe_add<float>(const int n, const float* a, const float* b,
float* y) { vsAdd(n, a, b, y); }
template <>
void caffe_add<double>(const int n, const double* a, const double* b,
double* y) { vdAdd(n, a, b, y); }
template <>
void caffe_sub<float>(const int n, const float* a, const float* b,
float* y) { vsSub(n, a, b, y); }
template <>
void caffe_sub<double>(const int n, const double* a, const double* b,
double* y) { vdSub(n, a, b, y); }
template <>
void caffe_mul<float>(const int n, const float* a, const float* b,
float* y) { vsMul(n, a, b, y); }
template <>
void caffe_mul<double>(const int n, const double* a, const double* b,
double* y) { vdMul(n, a, b, y); }
template <>
void caffe_div<float>(const int n, const float* a, const float* b,
float* y) { vsDiv(n, a, b, y); }
template <>
void caffe_div<double>(const int n, const double* a, const double* b,
double* y) { vdDiv(n, a, b, y); }
template <>
void caffe_powx<float>(const int n, const float* a, const float b,
float* y) { vsPowx(n, a, b, y); }
template <>
void caffe_powx<double>(const int n, const double* a, const double b,
double* y) { vdPowx(n, a, b, y); }
template <>
void caffe_vRngUniform<float>(const int n, float* r,
const float a, const float b) {
VSL_CHECK(vsRngUniform(VSL_RNG_METHOD_UNIFORM_STD, Caffe::vsl_stream(),
n, r, a, b));
}
template <>
void caffe_vRngUniform<double>(const int n, double* r,
const double a, const double b) {
VSL_CHECK(vdRngUniform(VSL_RNG_METHOD_UNIFORM_STD, Caffe::vsl_stream(),
n, r, a, b));
}
template <>
void caffe_vRngGaussian<float>(const int n, float* r, const float a,
const float sigma) {
VSL_CHECK(vsRngGaussian(VSL_RNG_METHOD_GAUSSIAN_BOXMULLER,
Caffe::vsl_stream(), n, r, a, sigma));
}
template <>
void caffe_vRngGaussian<double>(const int n, double* r, const double a,
const double sigma) {
VSL_CHECK(vdRngGaussian(VSL_RNG_METHOD_GAUSSIAN_BOXMULLER,
Caffe::vsl_stream(), n, r, a, sigma));
}
template <>
void caffe_exp<float>(const int n, const float* a, float* y) {
vsExp(n, a, y);
@@ -272,6 +266,86 @@ void caffe_exp<double>(const int n, const double* a, double* y) {
vdExp(n, a, y);
}
template <typename Dtype>
Dtype caffe_nextafter(const Dtype b) {
return boost::math::nextafter<Dtype>(
b, std::numeric_limits<Dtype>::max());
}
template
float caffe_nextafter(const float b);
template
double caffe_nextafter(const double b);
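// Note: boost::uniform_real samples the half-open interval [a, b), so
// caffe_vRngUniform below widens the upper endpoint to caffe_nextafter(b),
// the next representable value above b, making b itself a possible sample.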
template <typename Dtype>
void caffe_vRngUniform(const int n, Dtype* r,
const Dtype a, const Dtype b) {
CHECK_GE(n, 0);
CHECK(r);
CHECK_LE(a, b);
boost::uniform_real<Dtype> random_distribution(
a, caffe_nextafter<Dtype>(b));
boost::variate_generator<caffe::rng_t,
boost::uniform_real<Dtype> > variate_generator(
caffe_rng(), random_distribution);
for (int i = 0; i < n; ++i) {
r[i] = variate_generator();
}
}
template
void caffe_vRngUniform<float>(const int n, float* r,
const float a, const float b);
template
void caffe_vRngUniform<double>(const int n, double* r,
const double a, const double b);
template <typename Dtype>
void caffe_vRngGaussian(const int n, Dtype* r, const Dtype a,
const Dtype sigma) {
CHECK_GE(n, 0);
CHECK(r);
CHECK_GT(sigma, 0);
boost::normal_distribution<Dtype> random_distribution(a, sigma);
boost::variate_generator<caffe::rng_t,
boost::normal_distribution<Dtype> > variate_generator(
caffe_rng(), random_distribution);
for (int i = 0; i < n; ++i) {
r[i] = variate_generator();
}
}
template
void caffe_vRngGaussian<float>(const int n, float* r, const float a,
const float sigma);
template
void caffe_vRngGaussian<double>(const int n, double* r, const double a,
const double sigma);
template <typename Dtype>
void caffe_vRngBernoulli(const int n, Dtype* r, const double p) {
CHECK_GE(n, 0);
CHECK(r);
CHECK_GE(p, 0);
CHECK_LE(p, 1);
boost::bernoulli_distribution<double> random_distribution(p);
boost::variate_generator<caffe::rng_t,
boost::bernoulli_distribution<double> > variate_generator(
caffe_rng(), random_distribution);
for (int i = 0; i < n; ++i) {
r[i] = variate_generator();
}
}
template
void caffe_vRngBernoulli<int>(const int n, int* r, const double p);
template <>
float caffe_cpu_dot<float>(const int n, const float* x, const float* y) {
return cblas_sdot(n, x, 1, y, 1);