Commit f6619a062d
Yangqing Jia, 2013-09-17 11:25:50 -07:00
Parent 9a10ea2500
12 changed files with 313 additions and 33 deletions

View file

@@ -7,25 +7,25 @@
namespace caffeine {
template <typename Dtype>
void Blob<Dtype>::Reshape(const int num, const int channels, const int height,
const int width) {
void Blob<Dtype>::Reshape(const int num, const int height,
const int width, const int channels) {
CHECK_GT(num, 0);
CHECK_GT(channels, 0);
CHECK_GT(height, 0);
CHECK_GT(width, 0);
CHECK_GT(channels, 0);
num_ = num;
channels_ = channels;
height_ = height;
width_ = width;
count_ = num_ * channels_ * height_ * width_;
channels_ = channels;
count_ = num_ * height_ * width_ * channels_;
data_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
diff_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
}
template <typename Dtype>
Blob<Dtype>::Blob(const int num, const int channels, const int height,
const int width) {
Reshape(num, channels, height, width);
Blob<Dtype>::Blob(const int num, const int height,
const int width, const int channels) {
Reshape(num, height, width, channels);
}
template <typename Dtype>
@@ -84,7 +84,7 @@ void Blob<Dtype>::Update() {
template <typename Dtype>
void Blob<Dtype>::FromProto(const BlobProto& proto) {
Reshape(proto.num(), proto.channels(), proto.height(), proto.width());
Reshape(proto.num(), proto.height(), proto.width(), proto.channels());
// copy data
Dtype* data_vec = mutable_cpu_data();
for (int i = 0; i < count_; ++i) {
@@ -99,9 +99,9 @@ void Blob<Dtype>::FromProto(const BlobProto& proto) {
template <typename Dtype>
void Blob<Dtype>::ToProto(BlobProto* proto) {
proto->set_num(num_);
proto->set_channels(channels_);
proto->set_height(height_);
proto->set_width(width_);
proto->set_channels(channels_);
proto->clear_data();
proto->clear_diff();
const Dtype* data_vec = cpu_data();
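The net effect of this hunk is a storage-order change: Blob dimensions move from (num, channels, height, width) to (num, height, width, channels). As a reading aid only (this helper is hypothetical and not part of the commit), the flat offset of element (n, h, w, c) implied by count_ = num_ * height_ * width_ * channels_ in the new ordering would be:

// Hypothetical helper illustrating the new (num, height, width, channels) layout;
// not part of this commit, shown only to make the reordering concrete.
inline int offset_nhwc(int n, int h, int w, int c,
    int height, int width, int channels) {
  return ((n * height + h) * width + w) * channels + c;
}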

View file

@@ -13,15 +13,15 @@ class Blob {
Blob()
: num_(0), channels_(0), height_(0), width_(0), count_(0), data_(),
diff_() {};
explicit Blob(const int num, const int channels, const int height,
const int width);
explicit Blob(const int num, const int height,
const int width, const int channels);
virtual ~Blob() {};
void Reshape(const int num, const int channels, const int height,
const int width);
void Reshape(const int num, const int height,
const int width, const int channels);
inline int num() { return num_; }
inline int channels() { return channels_; }
inline int height() { return height_; }
inline int width() { return width_; }
inline int channels() { return channels_; }
inline int count() {return count_; }
const Dtype* cpu_data();
@@ -39,9 +39,9 @@ class Blob {
shared_ptr<SyncedMemory> data_;
shared_ptr<SyncedMemory> diff_;
int num_;
int channels_;
int height_;
int width_;
int channels_;
int count_;
}; // class Blob

View file

@@ -15,6 +15,10 @@
#define CURAND_CHECK(condition) CHECK_EQ((condition), CURAND_STATUS_SUCCESS)
#define VSL_CHECK(condition) CHECK_EQ((condition), VSL_STATUS_OK)
#define INSTANTIATE_CLASS(classname) \
template class classname<float>; \
template class classname<double>
namespace caffeine {
// We will use the boost shared_ptr instead of the new C++11 one mainly
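For reference, the new INSTANTIATE_CLASS macro stands in for the pairs of explicit template instantiations that are removed later in this commit; for example, INSTANTIATE_CLASS(DropoutLayer) expands to:

template class DropoutLayer<float>;
template class DropoutLayer<double>;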

View file

@@ -24,6 +24,11 @@ class Filler {
FillerParameter filler_param_;
}; // class Filler
template <typename Dtype>
class FillerFactory {
};
template <typename Dtype>
class ConstantFiller : public Filler<Dtype> {
public:
@@ -90,6 +95,24 @@ class GaussianFiller : public Filler<Dtype> {
};
};
// A function to get a specific filler from the specification given in
// FillerParameter. Ideally this would be replaced by a factory pattern,
// but we will leave it this way for now.
template <typename Dtype>
Filler<Dtype>* GetFiller(const FillerParameter& param) {
const std::string& type = param.type();
if (type == "constant") {
return new ConstantFiller<Dtype>(param);
} else if (type == "uniform") {
return new UniformFiller<Dtype>(param);
} else if (type == "gaussian") {
return new GaussianFiller<Dtype>(param);
} else {
CHECK(false) << "Unknown filler name: " << param.type();
}
return (Filler<Dtype>*)(NULL);
}
} // namespace caffeine
#endif // CAFFEINE_FILLER_HPP_
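A minimal usage sketch for GetFiller (set_type below is the standard protobuf-generated setter for FillerParameter.type; it is assumed here and not shown in this commit):

FillerParameter filler_param;
filler_param.set_type("gaussian");  // assumed protobuf setter
Blob<float> weights(1, 1, 10, 20);  // new (num, height, width, channels) order
shared_ptr<Filler<float> > filler(GetFiller<float>(filler_param));
filler->Fill(&weights);  // fills weights according to filler_param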

View file

@@ -6,7 +6,6 @@
#include "caffeine/syncedmem.hpp"
#include "caffeine/vision_layers.hpp"
using std::max;
namespace caffeine {
@@ -77,7 +76,6 @@ void DropoutLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
Dtype* top_data = (*top)[0]->mutable_gpu_data();
const int count = bottom[0]->count();
if (Caffeine::phase() == Caffeine::TRAIN) {
// Create random numbers
CURAND_CHECK(curandGenerate(Caffeine::curand_generator(),
(unsigned int*)(rand_vec_->mutable_gpu_data()), count));
// set thresholds
@@ -117,8 +115,7 @@ Dtype DropoutLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
return Dtype(0);
}
template class DropoutLayer<float>;
template class DropoutLayer<double>;
INSTANTIATE_CLASS(DropoutLayer);
} // namespace caffeine

View file

@@ -0,0 +1,156 @@
#include <mkl.h>
#include <cublas_v2.h>
#include "caffeine/blob.hpp"
#include "caffeine/common.hpp"
#include "caffeine/filler.hpp"
#include "caffeine/layer.hpp"
#include "caffeine/vision_layers.hpp"
namespace caffeine {
template <typename Dtype>
void InnerProductLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
CHECK_EQ(bottom.size(), 1) << "IP Layer takes a single blob as input.";
CHECK_EQ(top->size(), 1) << "IP Layer takes a single blob as output.";
const int num_output = this->layer_param_.num_output();
const bool gemm_last_dim = this->layer_param_.gemm_last_dim();
biasterm_ = this->layer_param_.biasterm();
// Figure out the dimensions
if (gemm_last_dim) {
M_ = bottom[0]->count() / bottom[0]->channels();
K_ = bottom[0]->channels();
N_ = num_output;
(*top)[0]->Reshape(bottom[0]->num(), bottom[0]->height(),
bottom[0]->width(), num_output);
} else {
M_ = bottom[0]->num();
K_ = bottom[0]->count() / bottom[0]->num();
N_ = num_output;
(*top)[0]->Reshape(bottom[0]->num(), 1, 1, num_output);
}
if (biasterm_) {
this->blobs_.resize(2);
} else {
this->blobs_.resize(1);
}
// Initialize the weight
Blob<Dtype>& weight = this->blobs_[0];
weight.Reshape(1, 1, K_, N_);
// fill the weights
shared_ptr<Filler<Dtype> > weight_filler(
GetFiller<Dtype>(this->layer_param_.weight_filler()));
weight_filler->Fill(&weight);
// If necessary, initialize and fill the bias term
if (biasterm_) {
Blob<Dtype>& bias = this->blobs_[1];
bias.Reshape(1, 1, 1, N_);
shared_ptr<Filler<Dtype> > bias_filler(
GetFiller<Dtype>(this->layer_param_.bias_filler()));
bias_filler->Fill(&bias);
}
};
template <typename Dtype>
void InnerProductLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
const Dtype* bottom_data = bottom[0]->cpu_data();
Dtype* top_data = (*top)[0]->mutable_cpu_data();
const Dtype* weight = this->blobs_[0].cpu_data();
const Dtype* bias = NULL;
if (biasterm_) {
bias = this->blobs_[1].cpu_data();
}
switch(sizeof(Dtype)) {
case sizeof(float):
// matrix multiply
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, M_, N_, K_,
1., (const float*)bottom_data, K_, (const float*)weight, N_, 0.,
(float*)top_data, N_);
if (bias) {
// add bias
for (int i = 0; i < M_; ++i) {
cblas_saxpy(N_, 1., (const float*)bias, 1,
(float*)(top_data) + (N_ * i), 1);
}
}
break;
case sizeof(double):
// matrix multiply
cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, M_, N_, K_,
1., (const double*)bottom_data, K_, (const double*)weight, N_, 0.,
(double*)top_data, N_);
if (bias) {
// add bias
for (int i = 0; i < M_; ++i) {
cblas_daxpy(N_, 1., (const double*)bias, 1,
(double*)(top_data) + (N_ * i), 1);
}
}
break;
default:
CHECK(false) << "Unknown data type.";
}
}
template <typename Dtype>
Dtype InnerProductLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down,
vector<Blob<Dtype>*>* bottom) {
CHECK(false);
return Dtype(0);
}
template <typename Dtype>
__global__ void BroadcastCopy(const int total, const int vec_len,
const Dtype* in_vec, Dtype* out_matrix) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
if (index < total) {
int v_index = index % vec_len;
out_matrix[index] = in_vec[v_index];
}
}
template <typename Dtype>
void InnerProductLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = (*top)[0]->mutable_gpu_data();
const Dtype* weight = this->blobs_[0].gpu_data();
const Dtype* bias = NULL;
Dtype alpha = 1., beta = 0.;
if (biasterm_) {
bias = this->blobs_[1].gpu_data();
beta = 1.;
const int count = (*top)[0]->count();
// we pre-copy the bias to the results, and then call gemm.
BroadcastCopy<<<CAFFEINE_GET_BLOCKS(count), CAFFEINE_CUDA_NUM_THREADS>>>(
count, N_, bias, top_data);
}
switch(sizeof(Dtype)) {
case sizeof(float):
// matrix multiply: since cublas uses Fortran major, we actually do
// C' = B' A'
CUBLAS_CHECK(cublasSgemm(Caffeine::cublas_handle(), CUBLAS_OP_N,
CUBLAS_OP_N, N_, M_, K_, (float*)&alpha, (const float*)weight, N_,
(const float*)bottom_data, K_, (float*)&beta, (float*)top_data, N_));
break;
case sizeof(double):
// matrix multiply
CUBLAS_CHECK(cublasDgemm(Caffeine::cublas_handle(), CUBLAS_OP_N,
CUBLAS_OP_N, N_, M_, K_, (double*)&alpha, (const double*)weight, N_,
(const double*)bottom_data, K_, (double*)&beta, (double*)top_data, N_));
break;
default:
CHECK(false) << "Unknown data type.";
}
}
template <typename Dtype>
Dtype InnerProductLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down,
vector<Blob<Dtype>*>* bottom) {
CHECK(false);
return Dtype(0.);
}
INSTANTIATE_CLASS(InnerProductLayer);
} // namespace caffeine
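Two details in this file are worth spelling out. First, the "Fortran major" comment: caffeine stores matrices row-major while cuBLAS assumes column-major, so the row-major product C = A * B (with C of size M_ x N_) is handed to cublasSgemm/cublasDgemm as its transpose. Since (A * B)' = B' * A', passing the weights as the first operand and the bottom data as the second, with sizes (N_, M_, K_) and leading dimensions N_, K_, N_, makes cuBLAS write exactly the row-major C into top_data with no explicit transpose. Second, when a bias term is present, BroadcastCopy first tiles the length-N_ bias vector across all M_ rows of top_data and beta is set to 1, so the gemm accumulates the product on top of the pre-copied bias; the bias addition is thereby fused into the gemm call rather than done as a separate pass.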

View file

@@ -12,7 +12,6 @@ void NeuronLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
bottom[0]->height(), bottom[0]->width());
};
template class NeuronLayer<float>;
template class NeuronLayer<double>;
INSTANTIATE_CLASS(NeuronLayer);
} // namespace caffeine

View file

@@ -75,8 +75,7 @@ Dtype ReLULayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
return Dtype(0);
}
template class ReLULayer<float>;
template class ReLULayer<double>;
INSTANTIATE_CLASS(ReLULayer);
} // namespace caffeine

View file

@@ -19,10 +19,14 @@ message LayerParameter {
optional float alpha = 13 [default = 1.]; // for local response norm
optional float beta = 14 [default = 0.75]; // for local response norm
// for innerproduct: if true, carry out inner product computation on the
// last dim only
optional bool gemm_last_dim = 15 [ default = false];
}
message FillerParameter {
required string type = 1;
required string type = 1 [default = 'constant'];
optional float value = 2 [default = 0]; // the value in constant filler
optional float min = 3 [default = 0]; // the min value in uniform filler
optional float max = 4 [default = 1]; // the max value in uniform filler
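A worked example of what gemm_last_dim changes in InnerProductLayer::SetUp (the shapes are made up for illustration; the formulas come from the SetUp code earlier in this commit): take a bottom blob of shape (num = 2, height = 4, width = 4, channels = 3) in the new ordering, so count = 96, and num_output = 10. With gemm_last_dim = true the product runs over the last (channel) dimension only: M_ = count / channels = 32, K_ = channels = 3, N_ = 10, and the top keeps the spatial shape, (2, 4, 4, 10). With the default gemm_last_dim = false the whole per-image volume is flattened: M_ = num = 2, K_ = count / num = 48, N_ = 10, and the top is (2, 1, 1, 10).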

View file

@@ -25,9 +25,9 @@ TYPED_TEST(BlobSimpleTest, TestInitialization) {
EXPECT_TRUE(this->blob_);
EXPECT_TRUE(this->blob_preshaped_);
EXPECT_EQ(this->blob_preshaped_->num(), 2);
EXPECT_EQ(this->blob_preshaped_->channels(), 3);
EXPECT_EQ(this->blob_preshaped_->height(), 4);
EXPECT_EQ(this->blob_preshaped_->width(), 5);
EXPECT_EQ(this->blob_preshaped_->height(), 3);
EXPECT_EQ(this->blob_preshaped_->width(), 4);
EXPECT_EQ(this->blob_preshaped_->channels(), 5);
EXPECT_EQ(this->blob_preshaped_->count(), 120);
EXPECT_EQ(this->blob_->num(), 0);
EXPECT_EQ(this->blob_->channels(), 0);
@@ -46,9 +46,9 @@ TYPED_TEST(BlobSimpleTest, TestPointers) {
TYPED_TEST(BlobSimpleTest, TestReshape) {
this->blob_->Reshape(2, 3, 4, 5);
EXPECT_EQ(this->blob_->num(), 2);
EXPECT_EQ(this->blob_->channels(), 3);
EXPECT_EQ(this->blob_->height(), 4);
EXPECT_EQ(this->blob_->width(), 5);
EXPECT_EQ(this->blob_->height(), 3);
EXPECT_EQ(this->blob_->width(), 4);
EXPECT_EQ(this->blob_->channels(), 5);
EXPECT_EQ(this->blob_->count(), 120);
}

View file

@@ -14,7 +14,7 @@ class NeuronLayerTest : public ::testing::Test {
protected:
NeuronLayerTest()
: blob_bottom_(new Blob<Dtype>(2, 3, 4, 5)),
blob_top_(new Blob<Dtype>(2, 3, 4, 5)) {
blob_top_(new Blob<Dtype>()) {
// fill the values
FillerParameter filler_param;
GaussianFiller<Dtype> filler(filler_param);
@@ -36,6 +36,7 @@ TYPED_TEST(NeuronLayerTest, TestReLUCPU) {
LayerParameter layer_param;
Caffeine::set_mode(Caffeine::CPU);
ReLULayer<TypeParam> layer(layer_param);
layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
// Now, check values
const TypeParam* bottom_data = this->blob_bottom_->cpu_data();
@@ -50,6 +51,7 @@ TYPED_TEST(NeuronLayerTest, TestReLUGPU) {
LayerParameter layer_param;
Caffeine::set_mode(Caffeine::GPU);
ReLULayer<TypeParam> layer(layer_param);
layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
// Now, check values
const TypeParam* bottom_data = this->blob_bottom_->cpu_data();
@@ -60,4 +62,76 @@ TYPED_TEST(NeuronLayerTest, TestReLUGPU) {
}
}
TYPED_TEST(NeuronLayerTest, TestDropoutCPU) {
LayerParameter layer_param;
Caffeine::set_mode(Caffeine::CPU);
Caffeine::set_phase(Caffeine::TRAIN);
DropoutLayer<TypeParam> layer(layer_param);
layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
// Now, check values
const TypeParam* bottom_data = this->blob_bottom_->cpu_data();
const TypeParam* top_data = this->blob_top_->cpu_data();
float scale = 1. / (1. - layer_param.dropout_ratio());
for (int i = 0; i < this->blob_bottom_->count(); ++i) {
if (top_data[i] != 0) {
EXPECT_EQ(top_data[i], bottom_data[i] * scale);
}
}
}
TYPED_TEST(NeuronLayerTest, TestDropoutCPUTestPhase) {
LayerParameter layer_param;
Caffeine::set_mode(Caffeine::CPU);
Caffeine::set_phase(Caffeine::TEST);
DropoutLayer<TypeParam> layer(layer_param);
layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
// Now, check values
const TypeParam* bottom_data = this->blob_bottom_->cpu_data();
const TypeParam* top_data = this->blob_top_->cpu_data();
float scale = 1. / (1. - layer_param.dropout_ratio());
for (int i = 0; i < this->blob_bottom_->count(); ++i) {
if (top_data[i] != 0) {
EXPECT_EQ(top_data[i], bottom_data[i]);
}
}
}
TYPED_TEST(NeuronLayerTest, TestDropoutGPU) {
LayerParameter layer_param;
Caffeine::set_mode(Caffeine::GPU);
Caffeine::set_phase(Caffeine::TRAIN);
DropoutLayer<TypeParam> layer(layer_param);
layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
// Now, check values
const TypeParam* bottom_data = this->blob_bottom_->cpu_data();
const TypeParam* top_data = this->blob_top_->cpu_data();
float scale = 1. / (1. - layer_param.dropout_ratio());
for (int i = 0; i < this->blob_bottom_->count(); ++i) {
if (top_data[i] != 0) {
EXPECT_EQ(top_data[i], bottom_data[i] * scale);
}
}
}
TYPED_TEST(NeuronLayerTest, TestDropoutGPUTestPhase) {
LayerParameter layer_param;
Caffeine::set_mode(Caffeine::GPU);
Caffeine::set_phase(Caffeine::TEST);
DropoutLayer<TypeParam> layer(layer_param);
layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
// Now, check values
const TypeParam* bottom_data = this->blob_bottom_->cpu_data();
const TypeParam* top_data = this->blob_top_->cpu_data();
float scale = 1. / (1. - layer_param.dropout_ratio());
for (int i = 0; i < this->blob_bottom_->count(); ++i) {
if (top_data[i] != 0) {
EXPECT_EQ(top_data[i], bottom_data[i]);
}
}
}
}
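The dropout expectations above follow the inverted-dropout convention: with dropout_ratio p, a unit is kept with probability 1 - p and surviving activations are scaled by 1 / (1 - p) during training, so E[top_i] = (1 - p) * bottom_i / (1 - p) = bottom_i and no rescaling is needed at test time. That is why the TRAIN-phase tests check top_data[i] == bottom_data[i] * scale for the surviving units, while the TEST-phase tests expect top_data[i] == bottom_data[i].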

View file

@@ -5,6 +5,8 @@
namespace caffeine {
// The neuron layer is a specific type of layer that just works on single
// elements.
template <typename Dtype>
class NeuronLayer : public Layer<Dtype> {
public:
@ -14,6 +16,7 @@ class NeuronLayer : public Layer<Dtype> {
vector<Blob<Dtype>*>* top);
};
template <typename Dtype>
class ReLULayer : public NeuronLayer<Dtype> {
public:
@@ -31,6 +34,7 @@ class ReLULayer : public NeuronLayer<Dtype> {
const bool propagate_down, vector<Blob<Dtype>*>* bottom);
};
template <typename Dtype>
class DropoutLayer : public NeuronLayer<Dtype> {
public:
@@ -55,8 +59,28 @@ class DropoutLayer : public NeuronLayer<Dtype> {
};
template <typename Dtype>
class InnerProductLayer : public Layer<Dtype> {
public:
explicit InnerProductLayer(const LayerParameter& param)
: Layer<Dtype>(param) {};
virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
protected:
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down, vector<Blob<Dtype>*>* bottom);
virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down, vector<Blob<Dtype>*>* bottom);
int M_;
int K_;
int N_;
bool biasterm_;
};
} // namespace caffeine