Hide boost rng behind facade for osx compatibility

Split boost random number generation from the common Caffe singleton and
add a helper function for rng. This resolves a build conflict in OSX
between boost rng and nvcc compilation of cuda code.

Refer to #165 for a full discussion.

Thanks to @satol for suggesting a random number generation facade rather
than a total split of cpp and cu code, which is far more involved.
This commit is contained in:
Evan Shelhamer 2014-03-21 23:47:01 -07:00
Родитель aaa26466eb
Коммит 19bcf2b29b
4 изменённых файлов: 120 добавлений и 51 удалений

Просмотреть файл

@ -1,9 +1,9 @@
// Copyright 2013 Yangqing Jia
// Copyright 2014 Evan Shelhamer
#ifndef CAFFE_COMMON_HPP_
#define CAFFE_COMMON_HPP_
#include <boost/random/mersenne_twister.hpp>
#include <boost/shared_ptr.hpp>
#include <cublas_v2.h>
#include <cuda.h>
@ -11,23 +11,6 @@
#include <driver_types.h> // cuda driver types
#include <glog/logging.h>
// various checks for different function calls.
#define CUDA_CHECK(condition) CHECK_EQ((condition), cudaSuccess)
#define CUBLAS_CHECK(condition) CHECK_EQ((condition), CUBLAS_STATUS_SUCCESS)
#define CURAND_CHECK(condition) CHECK_EQ((condition), CURAND_STATUS_SUCCESS)
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
i < (n); \
i += blockDim.x * gridDim.x)
// After a kernel is executed, this will check the error and if there is one,
// exit loudly.
#define CUDA_POST_KERNEL_CHECK \
if (cudaSuccess != cudaPeekAtLastError()) \
LOG(FATAL) << "Cuda kernel failed. Error: " \
<< cudaGetErrorString(cudaPeekAtLastError())
// Disable the copy and assignment operator for a class.
#define DISABLE_COPY_AND_ASSIGN(classname) \
private:\
@ -43,6 +26,24 @@ private:\
// is executed we will see a fatal log.
#define NOT_IMPLEMENTED LOG(FATAL) << "Not Implemented Yet"
// CUDA: various checks for different function calls.
#define CUDA_CHECK(condition) CHECK_EQ((condition), cudaSuccess)
#define CUBLAS_CHECK(condition) CHECK_EQ((condition), CUBLAS_STATUS_SUCCESS)
#define CURAND_CHECK(condition) CHECK_EQ((condition), CURAND_STATUS_SUCCESS)
// CUDA: grid stride looping
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
i < (n); \
i += blockDim.x * gridDim.x)
// CUDA: check for error after kernel execution and exit loudly if there is one.
#define CUDA_POST_KERNEL_CHECK \
if (cudaSuccess != cudaPeekAtLastError()) \
LOG(FATAL) << "Cuda kernel failed. Error: " \
<< cudaGetErrorString(cudaPeekAtLastError())
namespace caffe {
// We will use the boost shared_ptr instead of the new C++11 one mainly
@ -50,19 +51,6 @@ namespace caffe {
using boost::shared_ptr;
// We will use 1024 threads per block, which requires cuda sm_2x or above.
#if __CUDA_ARCH__ >= 200
const int CAFFE_CUDA_NUM_THREADS = 1024;
#else
const int CAFFE_CUDA_NUM_THREADS = 512;
#endif
inline int CAFFE_GET_BLOCKS(const int N) {
return (N + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS;
}
// A singleton class to hold common caffe stuff, such as the handler that
// caffe is going to use for cublas, curand, etc.
class Caffe {
@ -77,20 +65,32 @@ class Caffe {
enum Brew { CPU, GPU };
enum Phase { TRAIN, TEST };
// The getters for the variables.
// Returns the cublas handle.
// This random number generator facade hides boost and CUDA rng
// implementation from one another (for cross-platform compatibility).
class RNG {
public:
RNG();
explicit RNG(unsigned int seed);
~RNG();
RNG(const RNG&);
RNG& operator=(const RNG&);
const void* generator() const;
void* generator();
private:
class Generator;
Generator* generator_;
};
// Getters for boost rng, curand, and cublas handles
inline static RNG &rng_stream() {
return Get().random_generator_;
}
inline static cublasHandle_t cublas_handle() { return Get().cublas_handle_; }
// Returns the curand generator.
inline static curandGenerator_t curand_generator() {
return Get().curand_generator_;
}
// boost RNG
typedef boost::mt19937 random_generator_t;
inline static random_generator_t &rng_stream() {
return Get().random_generator_;
}
// Returns the mode: running on CPU or GPU.
inline static Brew mode() { return Get().mode_; }
// Returns the phase: TRAIN or TEST.
@ -114,7 +114,7 @@ class Caffe {
protected:
cublasHandle_t cublas_handle_;
curandGenerator_t curand_generator_;
random_generator_t random_generator_;
RNG random_generator_;
Brew mode_;
Phase phase_;
@ -128,6 +128,21 @@ class Caffe {
};
// CUDA: thread number configuration.
// Use 1024 threads per block, which requires cuda sm_2x or above,
// or fall back to attempt compatibility (best of luck to you).
#if __CUDA_ARCH__ >= 200
const int CAFFE_CUDA_NUM_THREADS = 1024;
#else
const int CAFFE_CUDA_NUM_THREADS = 512;
#endif
// CUDA: number of blocks for threads.
inline int CAFFE_GET_BLOCKS(const int N) {
return (N + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS;
}
} // namespace caffe
#endif // CAFFE_COMMON_HPP_

Просмотреть файл

@ -0,0 +1,19 @@
// Copyright 2014 Evan Shelhamer
#ifndef CAFFE_RNG_CPP_HPP_
#define CAFFE_RNG_CPP_HPP_
#include <boost/random/mersenne_twister.hpp>
#include "caffe/common.hpp"
namespace caffe {
typedef boost::mt19937 rng_t;
inline rng_t& caffe_rng() {
Caffe::RNG &generator = Caffe::rng_stream();
return *(caffe::rng_t*) generator.generator();
}
} // namespace caffe
#endif // CAFFE_RNG_HPP_

Просмотреть файл

@ -1,15 +1,18 @@
// Copyright 2013 Yangqing Jia
// Copyright 2014 Evan Shelhamer
#include <cstdio>
#include <ctime>
#include "caffe/common.hpp"
#include "caffe/util/rng.hpp"
namespace caffe {
shared_ptr<Caffe> Caffe::singleton_;
// curand seeding
int64_t cluster_seedgen(void) {
int64_t s, seed, pid;
pid = getpid();
@ -58,7 +61,7 @@ void Caffe::set_random_seed(const unsigned int seed) {
LOG(ERROR) << "Curand not available. Skipping setting the curand seed.";
}
// RNG seed
Get().random_generator_ = random_generator_t(seed);
Get().random_generator_ = RNG(seed);
}
void Caffe::SetDevice(const int device_id) {
@ -112,4 +115,37 @@ void Caffe::DeviceQuery() {
return;
}
class Caffe::RNG::Generator {
public:
caffe::rng_t rng;
};
Caffe::RNG::RNG()
: generator_(new Generator) { }
Caffe::RNG::RNG(unsigned int seed)
: generator_(new Generator) {
generator_->rng = caffe::rng_t(seed);
}
Caffe::RNG::~RNG() { delete generator_; }
Caffe::RNG::RNG(const RNG& other) : generator_(new Generator) {
*generator_ = *other.generator_;
}
Caffe::RNG& Caffe::RNG::operator=(const RNG& other) {
*generator_ = *other.generator_;
return *this;
}
void* Caffe::RNG::generator() {
return &generator_->rng;
}
const void* Caffe::RNG::generator() const {
return &generator_->rng;
}
} // namespace caffe

Просмотреть файл

@ -1,5 +1,6 @@
// Copyright 2013 Yangqing Jia
// Copyright 2014 kloudkl@github
// Copyright 2014 Evan Shelhamer
#include <boost/math/special_functions/next.hpp>
#include <boost/random.hpp>
@ -9,6 +10,7 @@
#include "caffe/common.hpp"
#include "caffe/util/math_functions.hpp"
#include "caffe/util/rng.hpp"
namespace caffe {
@ -287,10 +289,9 @@ void caffe_vRngUniform(const int n, Dtype* r,
boost::uniform_real<Dtype> random_distribution(
a, caffe_nextafter<Dtype>(b));
Caffe::random_generator_t &generator = Caffe::rng_stream();
boost::variate_generator<Caffe::random_generator_t,
boost::variate_generator<caffe::rng_t,
boost::uniform_real<Dtype> > variate_generator(
generator, random_distribution);
caffe_rng(), random_distribution);
for (int i = 0; i < n; ++i) {
r[i] = variate_generator();
@ -311,10 +312,9 @@ void caffe_vRngGaussian(const int n, Dtype* r, const Dtype a,
CHECK(r);
CHECK_GT(sigma, 0);
boost::normal_distribution<Dtype> random_distribution(a, sigma);
Caffe::random_generator_t &generator = Caffe::rng_stream();
boost::variate_generator<Caffe::random_generator_t,
boost::variate_generator<caffe::rng_t,
boost::normal_distribution<Dtype> > variate_generator(
generator, random_distribution);
caffe_rng(), random_distribution);
for (int i = 0; i < n; ++i) {
r[i] = variate_generator();
@ -336,10 +336,9 @@ void caffe_vRngBernoulli(const int n, Dtype* r, const double p) {
CHECK_GE(p, 0);
CHECK_LE(p, 1);
boost::bernoulli_distribution<double> random_distribution(p);
Caffe::random_generator_t &generator = Caffe::rng_stream();
boost::variate_generator<Caffe::random_generator_t,
boost::variate_generator<caffe::rng_t,
boost::bernoulli_distribution<double> > variate_generator(
generator, random_distribution);
caffe_rng(), random_distribution);
for (int i = 0; i < n; ++i) {
r[i] = variate_generator();