From 453fcf909522937abf1bd4e44efa4932d5d4aca6 Mon Sep 17 00:00:00 2001 From: Evan Shelhamer Date: Fri, 21 Mar 2014 14:58:11 -0700 Subject: [PATCH] clean up residual mkl comments and code The FIXMEs about RNG were addressed by caffe_nextafter for uniform distributions and the normal distribution concern is surely a typo in the boost documentation, since the normal pdf is correctly stated elsewhere in the documentation. --- include/caffe/common.hpp | 16 ++++------------ include/caffe/filler.hpp | 1 - src/caffe/common.cpp | 14 +------------- src/caffe/layers/dropout_layer.cpp | 2 -- src/caffe/layers/inner_product_layer.cpp | 3 --- src/caffe/layers/inner_product_layer.cu | 2 -- src/caffe/test/test_common.cpp | 11 ----------- src/caffe/test/test_util_blas.cpp | 1 - src/caffe/util/math_functions.cpp | 20 +++----------------- 9 files changed, 8 insertions(+), 62 deletions(-) diff --git a/include/caffe/common.hpp b/include/caffe/common.hpp index 9621b261..2ffc93f2 100644 --- a/include/caffe/common.hpp +++ b/include/caffe/common.hpp @@ -8,16 +8,13 @@ #include #include #include -// cuda driver types -#include +#include // cuda driver types #include -//#include // various checks for different function calls. #define CUDA_CHECK(condition) CHECK_EQ((condition), cudaSuccess) #define CUBLAS_CHECK(condition) CHECK_EQ((condition), CUBLAS_STATUS_SUCCESS) #define CURAND_CHECK(condition) CHECK_EQ((condition), CURAND_STATUS_SUCCESS) -#define VSL_CHECK(condition) CHECK_EQ((condition), VSL_STATUS_OK) #define CUDA_KERNEL_LOOP(i, n) \ for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ @@ -46,7 +43,6 @@ private:\ // is executed we will see a fatal log. #define NOT_IMPLEMENTED LOG(FATAL) << "Not Implemented Yet" - namespace caffe { // We will use the boost shared_ptr instead of the new C++11 one mainly @@ -62,7 +58,6 @@ using boost::shared_ptr; #endif - inline int CAFFE_GET_BLOCKS(const int N) { return (N + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS; } @@ -90,11 +85,9 @@ class Caffe { return Get().curand_generator_; } - // Returns the MKL random stream. - //inline static VSLStreamStatePtr vsl_stream() { return Get().vsl_stream_; } - + // boost RNG typedef boost::mt19937 random_generator_t; - inline static random_generator_t &vsl_stream() { return Get().random_generator_; } + inline static random_generator_t &rng_stream() { return Get().random_generator_; } // Returns the mode: running on CPU or GPU. inline static Brew mode() { return Get().mode_; } @@ -108,7 +101,7 @@ class Caffe { inline static void set_mode(Brew mode) { Get().mode_ = mode; } // Sets the phase. inline static void set_phase(Phase phase) { Get().phase_ = phase; } - // Sets the random seed of both MKL and curand + // Sets the random seed of both boost and curand static void set_random_seed(const unsigned int seed); // Sets the device. Since we have cublas and curand stuff, set device also // requires us to reset those values. @@ -119,7 +112,6 @@ class Caffe { protected: cublasHandle_t cublas_handle_; curandGenerator_t curand_generator_; - //VSLStreamStatePtr vsl_stream_; random_generator_t random_generator_; Brew mode_; diff --git a/include/caffe/filler.hpp b/include/caffe/filler.hpp index d0b5baa0..7c100224 100644 --- a/include/caffe/filler.hpp +++ b/include/caffe/filler.hpp @@ -7,7 +7,6 @@ #ifndef CAFFE_FILLER_HPP #define CAFFE_FILLER_HPP -//#include #include #include "caffe/common.hpp" diff --git a/src/caffe/common.cpp b/src/caffe/common.cpp index 95a5e93a..29501bb6 100644 --- a/src/caffe/common.cpp +++ b/src/caffe/common.cpp @@ -22,7 +22,6 @@ int64_t cluster_seedgen(void) { Caffe::Caffe() : mode_(Caffe::CPU), phase_(Caffe::TRAIN), cublas_handle_(NULL), curand_generator_(NULL), - //vsl_stream_(NULL) random_generator_() { // Try to create a cublas handler, and report an error if failed (but we will @@ -37,13 +36,6 @@ Caffe::Caffe() != CURAND_STATUS_SUCCESS) { LOG(ERROR) << "Cannot create Curand generator. Curand won't be available."; } - - // Try to create a vsl stream. This should almost always work, but we will - // check it anyway. - //if (vslNewStream(&vsl_stream_, VSL_BRNG_MT19937, cluster_seedgen()) != VSL_STATUS_OK) { - // LOG(ERROR) << "Cannot create vsl stream. VSL random number generator " - // << "won't be available."; - //} } Caffe::~Caffe() { @@ -51,7 +43,6 @@ Caffe::~Caffe() { if (curand_generator_) { CURAND_CHECK(curandDestroyGenerator(curand_generator_)); } - //if (vsl_stream_) VSL_CHECK(vslDeleteStream(&vsl_stream_)); } void Caffe::set_random_seed(const unsigned int seed) { @@ -67,11 +58,8 @@ void Caffe::set_random_seed(const unsigned int seed) { } else { LOG(ERROR) << "Curand not available. Skipping setting the curand seed."; } - // VSL seed - //VSL_CHECK(vslDeleteStream(&(Get().vsl_stream_))); - //VSL_CHECK(vslNewStream(&(Get().vsl_stream_), VSL_BRNG_MT19937, seed)); + // RNG seed Get().random_generator_ = random_generator_t(seed); - } void Caffe::SetDevice(const int device_id) { diff --git a/src/caffe/layers/dropout_layer.cpp b/src/caffe/layers/dropout_layer.cpp index bfb854bc..f07547ad 100644 --- a/src/caffe/layers/dropout_layer.cpp +++ b/src/caffe/layers/dropout_layer.cpp @@ -32,8 +32,6 @@ Dtype DropoutLayer::Forward_cpu(const vector*>& bottom, const int count = bottom[0]->count(); if (Caffe::phase() == Caffe::TRAIN) { // Create random numbers - //viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, Caffe::vsl_stream(), - // count, mask, 1. - threshold_); caffe_vRngBernoulli(count, mask, 1. - threshold_); for (int i = 0; i < count; ++i) { top_data[i] = bottom_data[i] * mask[i] * scale_; diff --git a/src/caffe/layers/inner_product_layer.cpp b/src/caffe/layers/inner_product_layer.cpp index a00e2f21..6ea228fe 100644 --- a/src/caffe/layers/inner_product_layer.cpp +++ b/src/caffe/layers/inner_product_layer.cpp @@ -1,8 +1,5 @@ // Copyright 2013 Yangqing Jia - -//#include - #include #include "caffe/blob.hpp" diff --git a/src/caffe/layers/inner_product_layer.cu b/src/caffe/layers/inner_product_layer.cu index 0d397dc0..37463b5a 100644 --- a/src/caffe/layers/inner_product_layer.cu +++ b/src/caffe/layers/inner_product_layer.cu @@ -1,7 +1,5 @@ // Copyright 2013 Yangqing Jia - -//#include #include #include diff --git a/src/caffe/test/test_common.cpp b/src/caffe/test/test_common.cpp index f5e3fe47..3ce15bba 100644 --- a/src/caffe/test/test_common.cpp +++ b/src/caffe/test/test_common.cpp @@ -19,11 +19,6 @@ TEST_F(CommonTest, TestCublasHandler) { EXPECT_TRUE(Caffe::cublas_handle()); } -TEST_F(CommonTest, TestVslStream) { - //EXPECT_TRUE(Caffe::vsl_stream()); - EXPECT_TRUE(true); -} - TEST_F(CommonTest, TestBrewMode) { Caffe::set_mode(Caffe::CPU); EXPECT_EQ(Caffe::mode(), Caffe::CPU); @@ -41,13 +36,9 @@ TEST_F(CommonTest, TestRandSeedCPU) { SyncedMemory data_a(10 * sizeof(int)); SyncedMemory data_b(10 * sizeof(int)); Caffe::set_random_seed(1701); - //viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, Caffe::vsl_stream(), - // 10, (int*)data_a.mutable_cpu_data(), 0.5); caffe_vRngBernoulli(10, reinterpret_cast(data_a.mutable_cpu_data()), 0.5); Caffe::set_random_seed(1701); - //viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, Caffe::vsl_stream(), - // 10, (int*)data_b.mutable_cpu_data(), 0.5); caffe_vRngBernoulli(10, reinterpret_cast(data_b.mutable_cpu_data()), 0.5); for (int i = 0; i < 10; ++i) { @@ -56,7 +47,6 @@ TEST_F(CommonTest, TestRandSeedCPU) { } } - TEST_F(CommonTest, TestRandSeedGPU) { SyncedMemory data_a(10 * sizeof(unsigned int)); SyncedMemory data_b(10 * sizeof(unsigned int)); @@ -72,5 +62,4 @@ TEST_F(CommonTest, TestRandSeedGPU) { } } - } // namespace caffe diff --git a/src/caffe/test/test_util_blas.cpp b/src/caffe/test/test_util_blas.cpp index 4ac49555..57f4eafc 100644 --- a/src/caffe/test/test_util_blas.cpp +++ b/src/caffe/test/test_util_blas.cpp @@ -3,7 +3,6 @@ #include #include "cuda_runtime.h" -//#include "mkl.h" #include "cublas_v2.h" #include "gtest/gtest.h" diff --git a/src/caffe/util/math_functions.cpp b/src/caffe/util/math_functions.cpp index fb2b1127..d68c05c3 100644 --- a/src/caffe/util/math_functions.cpp +++ b/src/caffe/util/math_functions.cpp @@ -2,7 +2,6 @@ // Copyright 2014 kloudkl@github #include -//#include #include #include @@ -284,14 +283,10 @@ void caffe_vRngUniform(const int n, Dtype* r, CHECK_GE(n, 0); CHECK(r); CHECK_LE(a, b); - //VSL_CHECK(vsRngUniform(VSL_RNG_METHOD_UNIFORM_STD, Caffe::vsl_stream(), - // n, r, a, b)); - // FIXME check if boundaries are handled in the same way ? - // Fixed by caffe_nextafter boost::uniform_real random_distribution( a, caffe_nextafter(b)); - Caffe::random_generator_t &generator = Caffe::vsl_stream(); + Caffe::random_generator_t &generator = Caffe::rng_stream(); boost::variate_generator > variate_generator( generator, random_distribution); @@ -314,17 +309,8 @@ void caffe_vRngGaussian(const int n, Dtype* r, const Dtype a, CHECK_GE(n, 0); CHECK(r); CHECK_GT(sigma, 0); - //VSL_CHECK(vsRngGaussian(VSL_RNG_METHOD_GAUSSIAN_BOXMULLER, -// Caffe::vsl_stream(), n, r, a, sigma)); - - // FIXME check if parameters are handled in the same way ? - // http://www.boost.org/doc/libs/1_55_0/doc/html/boost/random/normal_distribution.html - // http://software.intel.com/sites/products/documentation/hpc/mkl/mklman/GUID-63196F25-5013-4038-8BCD-2613C4EF3DE4.htm - // The above two documents show that the probability density functions are different. - // But the unit tests still pass. Maybe their codes are the same or - // the tests are irrelevant to the random numbers. boost::normal_distribution random_distribution(a, sigma); - Caffe::random_generator_t &generator = Caffe::vsl_stream(); + Caffe::random_generator_t &generator = Caffe::rng_stream(); boost::variate_generator > variate_generator( generator, random_distribution); @@ -349,7 +335,7 @@ void caffe_vRngBernoulli(const int n, Dtype* r, const double p) { CHECK_GE(p, 0); CHECK_LE(p, 1); boost::bernoulli_distribution random_distribution(p); - Caffe::random_generator_t &generator = Caffe::vsl_stream(); + Caffe::random_generator_t &generator = Caffe::rng_stream(); boost::variate_generator > variate_generator( generator, random_distribution);