From f6619a062d9e757a5f2c8eb43b1e11dc594321ee Mon Sep 17 00:00:00 2001
From: Yangqing Jia
Date: Tue, 17 Sep 2013 11:25:50 -0700
Subject: [PATCH] misc update

---
 src/caffeine/blob.cpp                      |  20 +--
 src/caffeine/blob.hpp                      |  12 +-
 src/caffeine/common.hpp                    |   4 +
 src/caffeine/filler.hpp                    |  23 +++
 src/caffeine/{ => layers}/dropout_layer.cu |   5 +-
 src/caffeine/layers/inner_product_layer.cu | 156 +++++++++++++++++++++
 src/caffeine/{ => layers}/neuron_layer.cpp |   3 +-
 src/caffeine/{ => layers}/relu_layer.cu    |   3 +-
 src/caffeine/proto/layer_param.proto       |   6 +-
 src/caffeine/test/test_blob.cpp            |  12 +-
 src/caffeine/test/test_neuron_layer.cpp    |  76 +++++++++-
 src/caffeine/vision_layers.hpp             |  26 +++-
 12 files changed, 313 insertions(+), 33 deletions(-)
 rename src/caffeine/{ => layers}/dropout_layer.cu (97%)
 create mode 100644 src/caffeine/layers/inner_product_layer.cu
 rename src/caffeine/{ => layers}/neuron_layer.cpp (87%)
 rename src/caffeine/{ => layers}/relu_layer.cu (97%)

diff --git a/src/caffeine/blob.cpp b/src/caffeine/blob.cpp
index 80d4acf9..ab565515 100644
--- a/src/caffeine/blob.cpp
+++ b/src/caffeine/blob.cpp
@@ -7,25 +7,25 @@
 namespace caffeine {
 
 template <typename Dtype>
-void Blob<Dtype>::Reshape(const int num, const int channels, const int height,
-    const int width) {
+void Blob<Dtype>::Reshape(const int num, const int height,
+    const int width, const int channels) {
   CHECK_GT(num, 0);
-  CHECK_GT(channels, 0);
   CHECK_GT(height, 0);
   CHECK_GT(width, 0);
+  CHECK_GT(channels, 0);
   num_ = num;
-  channels_ = channels;
   height_ = height;
   width_ = width;
-  count_ = num_ * channels_ * height_ * width_;
+  channels_ = channels;
+  count_ = num_ * height_ * width_ * channels_;
   data_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
   diff_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
 }
 
 template <typename Dtype>
-Blob<Dtype>::Blob(const int num, const int channels, const int height,
-    const int width) {
-  Reshape(num, channels, height, width);
+Blob<Dtype>::Blob(const int num, const int height,
+    const int width, const int channels) {
+  Reshape(num, height, width, channels);
 }
 
 template <typename Dtype>
@@ -84,7 +84,7 @@ void Blob<Dtype>::Update() {
 
 template <typename Dtype>
 void Blob<Dtype>::FromProto(const BlobProto& proto) {
-  Reshape(proto.num(), proto.channels(), proto.height(), proto.width());
+  Reshape(proto.num(), proto.height(), proto.width(), proto.channels());
   // copy data
   Dtype* data_vec = mutable_cpu_data();
   for (int i = 0; i < count_; ++i) {
@@ -99,9 +99,9 @@ void Blob<Dtype>::FromProto(const BlobProto& proto) {
 template <typename Dtype>
 void Blob<Dtype>::ToProto(BlobProto* proto) {
   proto->set_num(num_);
-  proto->set_channels(channels_);
   proto->set_height(height_);
   proto->set_width(width_);
+  proto->set_channels(channels_);
   proto->clear_data();
   proto->clear_diff();
   const Dtype* data_vec = cpu_data();
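The hunks above swap the blob layout from (num, channels, height, width) to (num, height, width, channels), so channels becomes the innermost, fastest-varying dimension in memory. The sketch below is not part of the patch; under that assumption it only illustrates how a flat index into the blob's data would be computed for the new ordering (the helper name offset_nhwc and its argument list are illustrative).

```cpp
// Illustrative only, not from this patch: flat index into a blob whose memory
// is laid out as (num, height, width, channels), with channels innermost.
inline int offset_nhwc(const int n, const int h, const int w, const int c,
                       const int height, const int width, const int channels) {
  return ((n * height + h) * width + w) * channels + c;
}

// Example: for a blob reshaped as (2, 3, 4, 5), i.e. num=2, height=3, width=4,
// channels=5, count() is 120 and offset_nhwc(1, 2, 3, 4, 3, 4, 5) == 119,
// the last element.
```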
diff --git a/src/caffeine/blob.hpp b/src/caffeine/blob.hpp
index 4c0bf0d9..e3aad865 100644
--- a/src/caffeine/blob.hpp
+++ b/src/caffeine/blob.hpp
@@ -13,15 +13,15 @@ class Blob {
   Blob()
       : num_(0), channels_(0), height_(0), width_(0), count_(0), data_(),
       diff_() {};
-  explicit Blob(const int num, const int channels, const int height,
-      const int width);
+  explicit Blob(const int num, const int height,
+      const int width, const int channels);
   virtual ~Blob() {};
-  void Reshape(const int num, const int channels, const int height,
-      const int width);
+  void Reshape(const int num, const int height,
+      const int width, const int channels);
   inline int num() { return num_; }
-  inline int channels() { return channels_; }
   inline int height() { return height_; }
   inline int width() { return width_; }
+  inline int channels() { return channels_; }
   inline int count() {return count_; }
 
   const Dtype* cpu_data();
@@ -39,9 +39,9 @@ class Blob {
   shared_ptr<SyncedMemory> data_;
   shared_ptr<SyncedMemory> diff_;
   int num_;
-  int channels_;
   int height_;
   int width_;
+  int channels_;
   int count_;
 };  // class Blob
 
diff --git a/src/caffeine/common.hpp b/src/caffeine/common.hpp
index 080cb9a6..cc104340 100644
--- a/src/caffeine/common.hpp
+++ b/src/caffeine/common.hpp
@@ -15,6 +15,10 @@
 #define CURAND_CHECK(condition) CHECK_EQ((condition), CURAND_STATUS_SUCCESS)
 #define VSL_CHECK(condition) CHECK_EQ((condition), VSL_STATUS_OK)
 
+#define INSTANTIATE_CLASS(classname) \
+  template class classname<float>; \
+  template class classname<double>
+
 namespace caffeine {
 
 // We will use the boost shared_ptr instead of the new C++11 one mainly
diff --git a/src/caffeine/filler.hpp b/src/caffeine/filler.hpp
index 880e6152..07f31da9 100644
--- a/src/caffeine/filler.hpp
+++ b/src/caffeine/filler.hpp
@@ -24,6 +24,11 @@ class Filler {
   FillerParameter filler_param_;
 };  // class Filler
 
+template <typename Dtype>
+class FillerFactory {
+
+};
+
 template <typename Dtype>
 class ConstantFiller : public Filler<Dtype> {
  public:
@@ -90,6 +95,24 @@ class GaussianFiller : public Filler<Dtype> {
   };
 };
 
+// A function to get a specific filler from the specification given in
+// FillerParameter. Ideally this would be replaced by a factory pattern,
+// but we will leave it this way for now.
+template <typename Dtype>
+Filler<Dtype>* GetFiller(const FillerParameter& param) {
+  const std::string& type = param.type();
+  if (type == "constant") {
+    return new ConstantFiller<Dtype>(param);
+  } else if (type == "uniform") {
+    return new UniformFiller<Dtype>(param);
+  } else if (type == "gaussian") {
+    return new GaussianFiller<Dtype>(param);
+  } else {
+    CHECK(false) << "Unknown filler name: " << param.type();
+  }
+  return (Filler<Dtype>*)(NULL);
+}
+
 }  // namespace caffeine
 
 #endif  // CAFFEINE_FILLER_HPP_
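The two additions above centralize boilerplate: INSTANTIATE_CLASS(classname) in common.hpp expands to the explicit float and double instantiations that each layer file previously spelled out by hand, and GetFiller in filler.hpp dispatches on the FillerParameter type string. The snippet below is only a usage sketch of GetFiller built on the patch's own headers; the function name FillExample and the choice of a gaussian filler are illustrative, not taken from the patch.

```cpp
#include "caffeine/blob.hpp"
#include "caffeine/filler.hpp"

using namespace caffeine;

void FillExample() {
  // Blob shaped (num, height, width, channels) under the new ordering,
  // mirroring how inner_product_layer.cu shapes its weight blob as (1, 1, K, N).
  Blob<float> weight(1, 1, 10, 20);
  FillerParameter filler_param;
  filler_param.set_type("gaussian");  // "constant" and "uniform" also dispatch
  shared_ptr<Filler<float> > filler(GetFiller<float>(filler_param));
  filler->Fill(&weight);              // fills the blob's data in place
}
```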
diff --git a/src/caffeine/dropout_layer.cu b/src/caffeine/layers/dropout_layer.cu
similarity index 97%
rename from src/caffeine/dropout_layer.cu
rename to src/caffeine/layers/dropout_layer.cu
index bfed41db..8dea15f9 100644
--- a/src/caffeine/dropout_layer.cu
+++ b/src/caffeine/layers/dropout_layer.cu
@@ -6,7 +6,6 @@
 #include "caffeine/syncedmem.hpp"
 #include "caffeine/vision_layers.hpp"
 
-
 using std::max;
 
 namespace caffeine {
@@ -77,7 +76,6 @@ void DropoutLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
   Dtype* top_data = (*top)[0]->mutable_gpu_data();
   const int count = bottom[0]->count();
   if (Caffeine::phase() == Caffeine::TRAIN) {
-    // Create random numbers
     CURAND_CHECK(curandGenerate(Caffeine::curand_generator(),
         (unsigned int*)(rand_vec_->mutable_gpu_data()), count));
     // set thresholds
@@ -117,8 +115,7 @@ Dtype DropoutLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
   return Dtype(0);
 }
 
-template class DropoutLayer<float>;
-template class DropoutLayer<double>;
+INSTANTIATE_CLASS(DropoutLayer);
 
 }  // namespace caffeine
 
diff --git a/src/caffeine/layers/inner_product_layer.cu b/src/caffeine/layers/inner_product_layer.cu
new file mode 100644
index 00000000..fa40093a
--- /dev/null
+++ b/src/caffeine/layers/inner_product_layer.cu
@@ -0,0 +1,156 @@
+#include
+#include
+
+#include "caffeine/blob.hpp"
+#include "caffeine/common.hpp"
+#include "caffeine/filler.hpp"
+#include "caffeine/layer.hpp"
+#include "caffeine/vision_layers.hpp"
+
+namespace caffeine {
+
+template <typename Dtype>
+void InnerProductLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) {
+  CHECK_EQ(bottom.size(), 1) << "IP Layer takes a single blob as input.";
+  CHECK_EQ(top->size(), 1) << "IP Layer takes a single blob as output.";
+  const int num_output = this->layer_param_.num_output();
+  const bool gemm_last_dim = this->layer_param_.gemm_last_dim();
+  biasterm_ = this->layer_param_.biasterm();
+  // Figure out the dimensions
+  if (gemm_last_dim) {
+    M_ = bottom[0]->count() / bottom[0]->channels();
+    K_ = bottom[0]->channels();
+    N_ = num_output;
+    (*top)[0]->Reshape(bottom[0]->num(), bottom[0]->height(),
+        bottom[0]->width(), num_output);
+  } else {
+    M_ = bottom[0]->num();
+    K_ = bottom[0]->count() / bottom[0]->num();
+    N_ = num_output;
+    (*top)[0]->Reshape(bottom[0]->num(), 1, 1, num_output);
+  }
+  if (biasterm_) {
+    this->blobs_.resize(2);
+  } else {
+    this->blobs_.resize(1);
+  }
+  // Initialize the weight
+  Blob<Dtype>& weight = this->blobs_[0];
+  weight.Reshape(1, 1, K_, N_);
+  // fill the weights
+  shared_ptr<Filler<Dtype> > weight_filler(
+      GetFiller<Dtype>(this->layer_param_.weight_filler()));
+  weight_filler->Fill(&weight);
+  // If necessary, initialize and fill the bias term
+  if (biasterm_) {
+    Blob<Dtype>& bias = this->blobs_[1];
+    bias.Reshape(1, 1, 1, N_);
+    shared_ptr<Filler<Dtype> > bias_filler(
+        GetFiller<Dtype>(this->layer_param_.bias_filler()));
+    bias_filler->Fill(&bias);
+  }
+};
+
+template <typename Dtype>
+void InnerProductLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+    vector<Blob<Dtype>*>* top) {
+  const Dtype* bottom_data = bottom[0]->cpu_data();
+  Dtype* top_data = (*top)[0]->mutable_cpu_data();
+  const Dtype* weight = this->blobs_[0].cpu_data();
+  const Dtype* bias = NULL;
+  if (biasterm_) {
+    bias = this->blobs_[1].cpu_data();
+  }
+  switch(sizeof(Dtype)) {
+  case sizeof(float):
+    // matrix multiply
+    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, M_, N_, K_,
+        1., (const float*)bottom_data, K_, (const float*)weight, N_, 0.,
+        (float*)top_data, N_);
+    if (bias) {
+      // add bias
+      for (int i = 0; i < M_; ++i) {
+        cblas_saxpy(N_, 1., (const float*)bias, 1,
+            (float*)(top_data) + (N_ * i), 1);
+      }
+    }
+  case sizeof(double):
+    // matrix multiply
+    cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, M_, N_, K_,
+        1., (const double*)bottom_data, K_, (const double*)weight, N_, 0.,
+        (double*)top_data, N_);
+    if (bias) {
+      // add bias
+      for (int i = 0; i < M_; ++i) {
+        cblas_daxpy(N_, 1., (const double*)bias, 1,
+            (double*)(top_data) + (N_ * i), 1);
+      }
+    }
+  default:
+    CHECK(false) << "Unknown data type.";
+  }
+}
+
+template <typename Dtype>
+Dtype InnerProductLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+    const bool propagate_down,
+    vector<Blob<Dtype>*>* bottom) {
+  CHECK(false);
+  return Dtype(0);
+}
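Forward_cpu above computes top = bottom x weight as one row-major GEMM, with top of shape M_ x N_, bottom M_ x K_ and weight K_ x N_, then adds the bias row-wise with one axpy per sample. Note that, as written, the switch on sizeof(Dtype) has no break statements, so the float case falls through into the double case and then into the CHECK(false) default; a reader reusing this pattern would add a break after each case. The loop below is only a reference sketch of the same arithmetic with illustrative names, not code from the patch.

```cpp
// Reference sketch (not from the patch): the arithmetic performed by the
// cblas_sgemm / cblas_saxpy calls in Forward_cpu, written as plain loops.
// top is M x N, bottom is M x K, weight is K x N, bias has length N (or NULL).
void InnerProductForwardReference(const float* bottom, const float* weight,
                                  const float* bias, float* top,
                                  const int M, const int N, const int K) {
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      float value = bias ? bias[j] : 0.f;  // row-wise bias broadcast
      for (int k = 0; k < K; ++k) {
        value += bottom[i * K + k] * weight[k * N + j];
      }
      top[i * N + j] = value;
    }
  }
}
```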
+
+template <typename Dtype>
+__global__ void BroadcastCopy(const int total, const int vec_len,
+    const Dtype* in_vec, Dtype* out_matrix) {
+  int index = threadIdx.x + blockIdx.x * blockDim.x;
+  if (index < total) {
+    int v_index = index % vec_len;
+    out_matrix[index] = in_vec[v_index];
+  }
+}
+
+template <typename Dtype>
+void InnerProductLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+    vector<Blob<Dtype>*>* top) {
+  const Dtype* bottom_data = bottom[0]->gpu_data();
+  Dtype* top_data = (*top)[0]->mutable_gpu_data();
+  const Dtype* weight = this->blobs_[0].gpu_data();
+  const Dtype* bias = NULL;
+  Dtype alpha = 1., beta = 0.;
+  if (biasterm_) {
+    bias = this->blobs_[1].gpu_data();
+    beta = 1.;
+    const int count = (*top)[0]->count();
+    // we pre-copy the bias to the results, and then call gemm.
+    BroadcastCopy<<>>(
+        count, N_, bias, top_data);
+  }
+  switch(sizeof(Dtype)) {
+  case sizeof(float):
+    // matrix multiply: since cublas uses Fortran major, we actually do
+    // C' = B' A'
+    CUBLAS_CHECK(cublasSgemm(Caffeine::cublas_handle(), CUBLAS_OP_N,
+        CUBLAS_OP_N, N_, M_, K_, (float*)&alpha, (const float*)weight, N_,
+        (const float*)bottom_data, K_, (float*)&beta, (float*)top_data, N_));
+  case sizeof(double):
+    // matrix multiply
+    CUBLAS_CHECK(cublasDgemm(Caffeine::cublas_handle(), CUBLAS_OP_N,
+        CUBLAS_OP_N, N_, M_, K_, (double*)&alpha, (const double*)weight, N_,
+        (const double*)bottom_data, K_, (double*)&beta, (double*)top_data, N_));
+  default:
+    CHECK(false) << "Unknown data type.";
+  }
+}
+
+template <typename Dtype>
+Dtype InnerProductLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+    const bool propagate_down,
+    vector<Blob<Dtype>*>* bottom) {
+  CHECK(false);
+  return Dtype(0.);
+}
+
+INSTANTIATE_CLASS(InnerProductLayer);
+
+}  // namespace caffeine
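The GPU path above relies on the standard trick for using column-major cuBLAS with row-major buffers: since (AB)^T = B^T A^T, asking cuBLAS for weight-transposed (N_ x K_ in its view) times bottom-transposed (K_ x M_) with leading dimensions N_, K_ and N_ yields top already in row-major order, with beta = 1 so the pre-copied bias is accumulated into. The snippet below restates that call pattern against the public cuBLAS API as a self-contained sketch; the handle argument, the names d_A, d_B, d_C and the m/n/k letters are illustrative (in the patch the handle comes from Caffeine::cublas_handle()).

```cpp
#include <cublas_v2.h>

// Compute row-major C = A * B with column-major cuBLAS by requesting
// C^T = B^T * A^T. A is m x k, B is k x n, C is m x n, all on the device.
cublasStatus_t RowMajorSgemm(cublasHandle_t handle, const float* d_A,
                             const float* d_B, float* d_C,
                             int m, int n, int k) {
  const float alpha = 1.0f;
  const float beta = 0.0f;
  // Swap the operands and the m/n extents; the leading dimensions are the
  // row-major row lengths (n for B and C, k for A).
  return cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                     n, m, k,
                     &alpha, d_B, n,
                     d_A, k,
                     &beta, d_C, n);
}
```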
diff --git a/src/caffeine/neuron_layer.cpp b/src/caffeine/layers/neuron_layer.cpp
similarity index 87%
rename from src/caffeine/neuron_layer.cpp
rename to src/caffeine/layers/neuron_layer.cpp
index 050c6906..4cac4342 100644
--- a/src/caffeine/neuron_layer.cpp
+++ b/src/caffeine/layers/neuron_layer.cpp
@@ -12,7 +12,6 @@ void NeuronLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
       bottom[0]->height(), bottom[0]->width());
 };
 
-template class NeuronLayer<float>;
-template class NeuronLayer<double>;
+INSTANTIATE_CLASS(NeuronLayer);
 
 }  // namespace caffeine
diff --git a/src/caffeine/relu_layer.cu b/src/caffeine/layers/relu_layer.cu
similarity index 97%
rename from src/caffeine/relu_layer.cu
rename to src/caffeine/layers/relu_layer.cu
index fb95b043..12a9b6c9 100644
--- a/src/caffeine/relu_layer.cu
+++ b/src/caffeine/layers/relu_layer.cu
@@ -75,8 +75,7 @@ Dtype ReLULayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
   return Dtype(0);
 }
 
-template class ReLULayer<float>;
-template class ReLULayer<double>;
+INSTANTIATE_CLASS(ReLULayer);
 
 }  // namespace caffeine
 
diff --git a/src/caffeine/proto/layer_param.proto b/src/caffeine/proto/layer_param.proto
index 7bb37089..58dbe932 100644
--- a/src/caffeine/proto/layer_param.proto
+++ b/src/caffeine/proto/layer_param.proto
@@ -19,10 +19,14 @@ message LayerParameter {
 
   optional float alpha = 13 [default = 1.];  // for local response norm
   optional float beta = 14 [default = 0.75];  // for local response norm
+
+  // for innerproduct: if true, carry out inner product computation on the
+  // last dim only
+  optional bool gemm_last_dim = 15 [ default = false];
 }
 
 message FillerParameter {
-  required string type = 1;
+  required string type = 1 [default = 'constant'];
   optional float value = 2 [default = 0];  // the value in constant filler
   optional float min = 3 [default = 0];  // the min value in uniform filler
   optional float max = 4 [default = 1];  // the max value in uniform filler
diff --git a/src/caffeine/test/test_blob.cpp b/src/caffeine/test/test_blob.cpp
index f5ebc634..ba564229 100644
--- a/src/caffeine/test/test_blob.cpp
+++ b/src/caffeine/test/test_blob.cpp
@@ -25,9 +25,9 @@ TYPED_TEST(BlobSimpleTest, TestInitialization) {
   EXPECT_TRUE(this->blob_);
   EXPECT_TRUE(this->blob_preshaped_);
   EXPECT_EQ(this->blob_preshaped_->num(), 2);
-  EXPECT_EQ(this->blob_preshaped_->channels(), 3);
-  EXPECT_EQ(this->blob_preshaped_->height(), 4);
-  EXPECT_EQ(this->blob_preshaped_->width(), 5);
+  EXPECT_EQ(this->blob_preshaped_->height(), 3);
+  EXPECT_EQ(this->blob_preshaped_->width(), 4);
+  EXPECT_EQ(this->blob_preshaped_->channels(), 5);
   EXPECT_EQ(this->blob_preshaped_->count(), 120);
   EXPECT_EQ(this->blob_->num(), 0);
   EXPECT_EQ(this->blob_->channels(), 0);
@@ -46,9 +46,9 @@ TYPED_TEST(BlobSimpleTest, TestPointers) {
 
 TYPED_TEST(BlobSimpleTest, TestReshape) {
   this->blob_->Reshape(2, 3, 4, 5);
   EXPECT_EQ(this->blob_->num(), 2);
-  EXPECT_EQ(this->blob_->channels(), 3);
-  EXPECT_EQ(this->blob_->height(), 4);
-  EXPECT_EQ(this->blob_->width(), 5);
+  EXPECT_EQ(this->blob_->height(), 3);
+  EXPECT_EQ(this->blob_->width(), 4);
+  EXPECT_EQ(this->blob_->channels(), 5);
   EXPECT_EQ(this->blob_->count(), 120);
 }
diff --git a/src/caffeine/test/test_neuron_layer.cpp b/src/caffeine/test/test_neuron_layer.cpp
index 92a50a59..d64a0145 100644
--- a/src/caffeine/test/test_neuron_layer.cpp
+++ b/src/caffeine/test/test_neuron_layer.cpp
@@ -14,7 +14,7 @@ class NeuronLayerTest : public ::testing::Test {
  protected:
   NeuronLayerTest()
       : blob_bottom_(new Blob<Dtype>(2, 3, 4, 5)),
-        blob_top_(new Blob<Dtype>(2, 3, 4, 5)) {
+        blob_top_(new Blob<Dtype>()) {
     // fill the values
     FillerParameter filler_param;
     GaussianFiller<Dtype> filler(filler_param);
@@ -36,6 +36,7 @@ TYPED_TEST(NeuronLayerTest, TestReLUCPU) {
   LayerParameter layer_param;
   Caffeine::set_mode(Caffeine::CPU);
   ReLULayer<TypeParam> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
   layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
   // Now, check values
   const TypeParam* bottom_data = this->blob_bottom_->cpu_data();
@@ -50,6 +51,7 @@ TYPED_TEST(NeuronLayerTest, TestReLUGPU) {
   LayerParameter layer_param;
   Caffeine::set_mode(Caffeine::GPU);
   ReLULayer<TypeParam> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
   layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
   // Now, check values
   const TypeParam* bottom_data = this->blob_bottom_->cpu_data();
@@ -60,4 +62,76 @@ TYPED_TEST(NeuronLayerTest, TestReLUGPU) {
   }
 }
 
+TYPED_TEST(NeuronLayerTest, TestDropoutCPU) {
+  LayerParameter layer_param;
+  Caffeine::set_mode(Caffeine::CPU);
+  Caffeine::set_phase(Caffeine::TRAIN);
+  DropoutLayer<TypeParam> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  // Now, check values
+  const TypeParam* bottom_data = this->blob_bottom_->cpu_data();
+  const TypeParam* top_data = this->blob_top_->cpu_data();
+  float scale = 1. / (1. - layer_param.dropout_ratio());
+  for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+    if (top_data[i] != 0) {
+      EXPECT_EQ(top_data[i], bottom_data[i] * scale);
+    }
+  }
+}
+
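The dropout tests encode the inverted-dropout convention: during training, each surviving activation is scaled by 1 / (1 - dropout_ratio) so that the expected value of the output matches the input, and at test time the layer acts as an identity (checked by the test-phase cases that follow). Below is a minimal sketch of that training-phase rule; the ratio value 0.5 is only an example (the dropout_ratio default is not shown in this patch), and the std::mt19937-based masking is illustrative rather than the curand/VSL path the layer actually uses.

```cpp
#include <random>
#include <vector>

// Minimal sketch of inverted dropout (training phase). Surviving units are
// scaled by 1 / (1 - ratio) so that E[top[i]] == bottom[i]; at test time the
// layer would simply copy bottom to top.
void DropoutForwardSketch(const std::vector<float>& bottom,
                          std::vector<float>* top, float ratio,
                          std::mt19937* rng) {
  const float scale = 1.f / (1.f - ratio);  // e.g. ratio = 0.5 gives scale = 2
  std::uniform_real_distribution<float> uniform(0.f, 1.f);
  top->resize(bottom.size());
  for (size_t i = 0; i < bottom.size(); ++i) {
    const bool keep = uniform(*rng) >= ratio;    // drop with probability ratio
    (*top)[i] = keep ? bottom[i] * scale : 0.f;
  }
}
```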
+TYPED_TEST(NeuronLayerTest, TestDropoutCPUTestPhase) {
+  LayerParameter layer_param;
+  Caffeine::set_mode(Caffeine::CPU);
+  Caffeine::set_phase(Caffeine::TEST);
+  DropoutLayer<TypeParam> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  // Now, check values
+  const TypeParam* bottom_data = this->blob_bottom_->cpu_data();
+  const TypeParam* top_data = this->blob_top_->cpu_data();
+  float scale = 1. / (1. - layer_param.dropout_ratio());
+  for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+    if (top_data[i] != 0) {
+      EXPECT_EQ(top_data[i], bottom_data[i]);
+    }
+  }
+}
+
+TYPED_TEST(NeuronLayerTest, TestDropoutGPU) {
+  LayerParameter layer_param;
+  Caffeine::set_mode(Caffeine::GPU);
+  Caffeine::set_phase(Caffeine::TRAIN);
+  DropoutLayer<TypeParam> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  // Now, check values
+  const TypeParam* bottom_data = this->blob_bottom_->cpu_data();
+  const TypeParam* top_data = this->blob_top_->cpu_data();
+  float scale = 1. / (1. - layer_param.dropout_ratio());
+  for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+    if (top_data[i] != 0) {
+      EXPECT_EQ(top_data[i], bottom_data[i] * scale);
+    }
+  }
+}
+
+TYPED_TEST(NeuronLayerTest, TestDropoutGPUTestPhase) {
+  LayerParameter layer_param;
+  Caffeine::set_mode(Caffeine::GPU);
+  Caffeine::set_phase(Caffeine::TEST);
+  DropoutLayer<TypeParam> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  // Now, check values
+  const TypeParam* bottom_data = this->blob_bottom_->cpu_data();
+  const TypeParam* top_data = this->blob_top_->cpu_data();
+  float scale = 1. / (1. - layer_param.dropout_ratio());
+  for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+    if (top_data[i] != 0) {
+      EXPECT_EQ(top_data[i], bottom_data[i]);
+    }
+  }
+}
+
 }
diff --git a/src/caffeine/vision_layers.hpp b/src/caffeine/vision_layers.hpp
index 432da7eb..8cf361c5 100644
--- a/src/caffeine/vision_layers.hpp
+++ b/src/caffeine/vision_layers.hpp
@@ -5,6 +5,8 @@
 
 namespace caffeine {
 
+// The neuron layer is a specific type of layer that just works on single
+// elements.
 template <typename Dtype>
 class NeuronLayer : public Layer<Dtype> {
  public:
@@ -14,6 +16,7 @@ class NeuronLayer : public Layer<Dtype> {
       vector<Blob<Dtype>*>* top);
 };
 
+
 template <typename Dtype>
 class ReLULayer : public NeuronLayer<Dtype> {
  public:
@@ -31,6 +34,7 @@ class ReLULayer : public NeuronLayer<Dtype> {
       const bool propagate_down, vector<Blob<Dtype>*>* bottom);
 };
 
+
 template <typename Dtype>
 class DropoutLayer : public NeuronLayer<Dtype> {
  public:
@@ -55,8 +59,28 @@ class DropoutLayer : public NeuronLayer<Dtype> {
 };
 
 
+template <typename Dtype>
+class InnerProductLayer : public Layer<Dtype> {
+ public:
+  explicit InnerProductLayer(const LayerParameter& param)
+      : Layer<Dtype>(param) {};
+  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+ protected:
+  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
-
+  virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+  virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+  int M_;
+  int K_;
+  int N_;
+  bool biasterm_;
+};
 
 
 
 }  // namespace caffeine
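Taken together with the tests above, the new InnerProductLayer is driven the same way as the other layers: construct it from a LayerParameter, call SetUp once to shape the top blob and initialize blobs_, then call Forward. The sketch below shows that flow; set_num_output, set_biasterm and mutable_weight_filler are assumed to be the usual protobuf-generated accessors for fields this patch only reads (num_output, biasterm, weight_filler), so treat those exact names, and the chosen shapes, as assumptions rather than part of the patch.

```cpp
#include <vector>

#include "caffeine/blob.hpp"
#include "caffeine/filler.hpp"
#include "caffeine/vision_layers.hpp"

using namespace caffeine;

void InnerProductExample() {
  // Bottom blob shaped (num, height, width, channels) under the new ordering.
  Blob<float> bottom(2, 3, 4, 5);
  Blob<float> top;
  FillerParameter filler_param;
  GaussianFiller<float> filler(filler_param);
  filler.Fill(&bottom);

  LayerParameter layer_param;
  layer_param.set_num_output(10);                       // assumed generated setter
  layer_param.set_biasterm(false);                      // skip the bias blob
  layer_param.mutable_weight_filler()->set_type("uniform");

  Caffeine::set_mode(Caffeine::CPU);
  std::vector<Blob<float>*> bottom_vec(1, &bottom);
  std::vector<Blob<float>*> top_vec(1, &top);
  InnerProductLayer<float> layer(layer_param);
  layer.SetUp(bottom_vec, &top_vec);    // gemm_last_dim defaults to false,
                                        // so top is reshaped to (2, 1, 1, 10)
  layer.Forward(bottom_vec, &top_vec);  // dispatches to the CPU path here
}
```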