Merge branch 'master' of github.com:Yangqing/caffe

Conflicts:
	Makefile
Yangqing Jia 2013-11-11 17:08:27 -08:00
Parents 76bf486e3d c8e7cce731
Commit 652d744360
9 changed files with 80 additions and 5 deletions

View file

@@ -45,7 +45,8 @@ INCLUDE_DIRS := ./src ./include /usr/local/include $(CUDA_INCLUDE_DIR) \
$(MKL_INCLUDE_DIR)
LIBRARY_DIRS := /usr/lib /usr/local/lib $(CUDA_LIB_DIR) $(MKL_LIB_DIR)
LIBRARIES := cuda cudart cublas curand protobuf opencv_core opencv_highgui \
-glog mkl_rt mkl_intel_thread leveldb snappy pthread boost_system
+glog mkl_rt mkl_intel_thread leveldb snappy pthread boost_system \
+opencv_imgproc
WARNINGS := -Wall
COMMON_FLAGS := -DNDEBUG $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir))

View file

@@ -16,6 +16,7 @@ class Solver {
virtual void Solve(const char* resume_file = NULL);
inline void Solve(const string resume_file) { Solve(resume_file.c_str()); }
virtual ~Solver() {}
inline Net<Dtype>* net() { return net_.get(); }
protected:
// PreSolve is run before any solving iteration starts, allowing one to
@@ -36,6 +37,7 @@ class Solver {
// function that restores the state from a SolverState protocol buffer.
void Restore(const char* resume_file);
virtual void RestoreSolverState(const SolverState& state) = 0;
SolverParameter param_;
int iter_;
shared_ptr<Net<Dtype> > net_;
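
The new net() accessor hands out the raw pointer held by net_, so callers can inspect or reuse the trained network while the solver's shared_ptr keeps ownership. A minimal usage sketch, assuming an SGDSolver subclass and a filled-in SolverParameter (none of which is shown in this diff):

#include <glog/logging.h>
#include "caffe/optimization/solver.hpp"  // header path is an assumption for this tree layout

void TrainAndInspect(const caffe::SolverParameter& solver_param) {
  caffe::SGDSolver<float> solver(solver_param);
  solver.Solve();
  // Borrow the trained net without taking ownership; the pointer stays
  // valid only while `solver` (and its shared_ptr) is alive.
  caffe::Net<float>* net = solver.net();
  CHECK(net != NULL) << "solver should hold a net after Solve()";
}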

View file

@@ -45,7 +45,7 @@ bool ReadImageToDatum(const string& filename, const int label,
inline bool ReadImageToDatum(const string& filename, const int label,
Datum* datum) {
-ReadImageToDatum(filename, label, 0, 0, datum);
+return ReadImageToDatum(filename, label, 0, 0, datum);
}
} // namespace caffe
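
This fix matters for correctness, not just style: the old body fell off the end of a function declared bool without returning, which is undefined behavior, and callers could not tell whether the image was actually read. A small hedged usage sketch (file name and logging are illustrative only):

#include <glog/logging.h>
#include "caffe/proto/caffe.pb.h"
#include "caffe/util/io.hpp"

void LoadOneImage() {
  caffe::Datum datum;
  // The convenience overload forwards to the resizing version with 0x0,
  // i.e. keep the original size, and now propagates its success flag.
  if (!caffe::ReadImageToDatum("example.jpg", /*label=*/3, &datum)) {
    LOG(ERROR) << "Failed to read image into Datum.";
  }
}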

View file

@@ -345,6 +345,29 @@ class MultinomialLogisticLossLayer : public Layer<Dtype> {
// const bool propagate_down, vector<Blob<Dtype>*>* bottom);
};
template <typename Dtype>
class InfogainLossLayer : public Layer<Dtype> {
public:
explicit InfogainLossLayer(const LayerParameter& param)
: Layer<Dtype>(param), infogain_() {}
virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
protected:
// The loss layer will do nothing during forward - all computation is
// carried out in the backward pass.
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) { return; }
virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) { return; }
virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down, vector<Blob<Dtype>*>* bottom);
// virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
// const bool propagate_down, vector<Blob<Dtype>*>* bottom);
Blob<Dtype> infogain_;
};
// SoftmaxWithLossLayer is a layer that implements softmax and then computes
// the loss - it is preferred over softmax + multinomiallogisticloss in the
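
The infogain_ member above holds the square weighting matrix H that the layer applies to the log-probabilities; the SetUp code added later in this commit loads it from a BlobProto binary file named by the layer's source field and requires num = channels = 1 and height == width. A hedged sketch of producing such a file, assuming the WriteProtoToBinaryFile helper and the file name (neither appears in this diff):

#include "caffe/proto/caffe.pb.h"
#include "caffe/util/io.hpp"

void WriteIdentityInfogainMatrix(int num_classes) {
  caffe::BlobProto blob;
  blob.set_num(1);
  blob.set_channels(1);
  blob.set_height(num_classes);
  blob.set_width(num_classes);
  for (int i = 0; i < num_classes; ++i) {
    for (int j = 0; j < num_classes; ++j) {
      // An identity H reproduces the plain multinomial logistic loss;
      // off-diagonal entries let misclassifications be penalized unevenly.
      blob.add_data(i == j ? 1.0f : 0.0f);
    }
  }
  caffe::WriteProtoToBinaryFile(blob, "infogain_matrix.binaryproto");
}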

View file

@@ -33,6 +33,8 @@ Layer<Dtype>* GetLayer(const LayerParameter& param) {
return new EuclideanLossLayer<Dtype>(param);
} else if (type == "im2col") {
return new Im2colLayer<Dtype>(param);
} else if (type == "infogain_loss") {
return new InfogainLossLayer<Dtype>(param);
} else if (type == "innerproduct") {
return new InnerProductLayer<Dtype>(param);
} else if (type == "lrn") {

View file

@@ -6,6 +6,7 @@
#include "caffe/layer.hpp"
#include "caffe/vision_layers.hpp"
#include "caffe/util/math_functions.hpp"
#include "caffe/util/io.hpp"
using std::max;
@@ -17,7 +18,7 @@ template <typename Dtype>
void MultinomialLogisticLossLayer<Dtype>::SetUp(
const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
CHECK_EQ(bottom.size(), 2) << "Loss Layer takes two blobs as input.";
-CHECK_EQ(top->size(), 0) << "Loss Layer takes no as output.";
+CHECK_EQ(top->size(), 0) << "Loss Layer takes no output.";
CHECK_EQ(bottom[0]->num(), bottom[1]->num())
<< "The data and label should have the same number.";
CHECK_EQ(bottom[1]->channels(), 1);
@@ -49,6 +50,49 @@ Dtype MultinomialLogisticLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>
// TODO: implement the GPU version for multinomial loss
template <typename Dtype>
void InfogainLossLayer<Dtype>::SetUp(
const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
CHECK_EQ(bottom.size(), 2) << "Loss Layer takes two blobs as input.";
CHECK_EQ(top->size(), 0) << "Loss Layer takes no output.";
CHECK_EQ(bottom[0]->num(), bottom[1]->num())
<< "The data and label should have the same number.";
CHECK_EQ(bottom[1]->channels(), 1);
CHECK_EQ(bottom[1]->height(), 1);
CHECK_EQ(bottom[1]->width(), 1);
BlobProto blob_proto;
ReadProtoFromBinaryFile(this->layer_param_.source(), &blob_proto);
infogain_.FromProto(blob_proto);
CHECK_EQ(infogain_.num(), 1);
CHECK_EQ(infogain_.channels(), 1);
CHECK_EQ(infogain_.height(), infogain_.width());
};
template <typename Dtype>
Dtype InfogainLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down,
vector<Blob<Dtype>*>* bottom) {
const Dtype* bottom_data = (*bottom)[0]->cpu_data();
const Dtype* bottom_label = (*bottom)[1]->cpu_data();
const Dtype* infogain_mat = infogain_.cpu_data();
Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
int num = (*bottom)[0]->num();
int dim = (*bottom)[0]->count() / (*bottom)[0]->num();
CHECK_EQ(infogain_.height(), dim);
Dtype loss = 0;
for (int i = 0; i < num; ++i) {
int label = static_cast<int>(bottom_label[i]);
for (int j = 0; j < dim; ++j) {
Dtype prob = max(bottom_data[i * dim + j], kLOG_THRESHOLD);
loss -= infogain_mat[label * dim + j] * log(prob);
bottom_diff[i * dim + j] = - infogain_mat[label * dim + j] / prob / num;
}
}
return loss / num;
}
template <typename Dtype>
void EuclideanLossLayer<Dtype>::SetUp(
const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
@@ -122,6 +166,7 @@ void AccuracyLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
}
INSTANTIATE_CLASS(MultinomialLogisticLossLayer);
INSTANTIATE_CLASS(InfogainLossLayer);
INSTANTIATE_CLASS(EuclideanLossLayer);
INSTANTIATE_CLASS(AccuracyLayer);
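
For reference, a rough restatement of what the InfogainLossLayer::Backward_cpu loop above computes, with N examples, D classes, label \ell_i, clipped probability \hat{p}_{ij} = \max(p_{ij}, kLOG\_THRESHOLD), and infogain matrix H:

\mathcal{L} = -\frac{1}{N}\sum_{i=1}^{N}\sum_{j=1}^{D} H_{\ell_i, j}\,\log \hat{p}_{ij},
\qquad
\frac{\partial \mathcal{L}}{\partial p_{ij}} = -\frac{H_{\ell_i, j}}{N\,\hat{p}_{ij}}.

When H is the identity matrix this reduces to the multinomial logistic loss implemented just above it.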

View file

@@ -120,7 +120,8 @@ __global__ void StoPoolForwardTest(const int nthreads,
int hend = min(hstart + ksize, height);
int wstart = pw * stride;
int wend = min(wstart + ksize, width);
-Dtype cumsum = 0.;
+// We set cumsum to be 0 to avoid divide-by-zero problems
+Dtype cumsum = FLT_MIN;
Dtype cumvalues = 0.;
bottom_data += (n * channels + c) * height * width;
// First pass: get sum
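
The FLT_MIN seed above exists because the test-time stochastic pooling kernel later divides by cumsum; with an all-zero pooling window an accumulator started at exactly 0 would produce 0/0 (NaN). A toy host-side C++ illustration of the hazard (not the kernel code itself; the eventual division by cumsum is inferred from the pass structure):

#include <cfloat>   // FLT_MIN
#include <cstdio>

int main() {
  const float window[3] = {0.f, 0.f, 0.f};  // an all-zero pooling window
  float cumsum = 0.f;
  float cumvalues = 0.f;
  for (int i = 0; i < 3; ++i) {
    cumsum += window[i];
    cumvalues += window[i] * window[i];
  }
  std::printf("%f\n", cumvalues / cumsum);              // 0/0 -> nan
  std::printf("%f\n", cumvalues / (FLT_MIN + cumsum));  // well-defined, ~0
  return 0;
}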

View file

@@ -287,7 +287,7 @@ void Net<Dtype>::CopyTrainedLayersFrom(const NetParameter& param) {
DLOG(INFO) << "Ignoring source layer " << source_layer_name;
continue;
}
-DLOG(INFO) << "Loading source layer " << source_layer_name;
+LOG(INFO) << "Copying source layer " << source_layer_name;
vector<shared_ptr<Blob<Dtype> > >& target_blobs =
layers_[target_layer_id]->blobs();
CHECK_EQ(target_blobs.size(), source_layer.blobs_size())

View file

@@ -7,6 +7,7 @@
#include <google/protobuf/io/coded_stream.h>
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <algorithm>
#include <string>
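
The new imgproc header here, together with opencv_imgproc being added to LIBRARIES in the Makefile hunk above, is what brings in cv::resize. A hedged sketch of the kind of read-and-resize path the width/height-aware ReadImageToDatum presumably uses (its actual body is not shown in this diff):

#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>  // cv::imread lives here in OpenCV 2.x
#include <opencv2/imgproc/imgproc.hpp>  // cv::resize
#include <string>

// Illustrative only: read an image from disk and resize it if a target
// shape was requested (0x0 meaning keep the original size).
cv::Mat ReadAndMaybeResize(const std::string& filename, int height, int width) {
  cv::Mat img = cv::imread(filename, CV_LOAD_IMAGE_COLOR);
  if (!img.empty() && height > 0 && width > 0) {
    cv::Mat resized;
    cv::resize(img, resized, cv::Size(width, height));
    return resized;
  }
  return img;
}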