From 749ba87e48049a7676f4dca8052d17fff6c485af Mon Sep 17 00:00:00 2001 From: Yangqing Jia Date: Wed, 23 Oct 2013 11:14:50 -0700 Subject: [PATCH] need backward computation, and train_net resume point. Not debugged. --- examples/train_net.cpp | 9 ++++++-- include/caffe/net.hpp | 6 +++++ src/caffe/net.cpp | 51 ++++++++++++++++++++++++++++++++---------- 3 files changed, 52 insertions(+), 14 deletions(-) diff --git a/examples/train_net.cpp b/examples/train_net.cpp index 06ca1213..b4181d64 100644 --- a/examples/train_net.cpp +++ b/examples/train_net.cpp @@ -3,7 +3,7 @@ // This is a simple script that allows one to quickly train a network whose // parameters are specified by text format protocol buffers. // Usage: -// train_net net_proto_file solver_proto_file +// train_net net_proto_file solver_proto_file [resume_point_file] #include @@ -28,7 +28,12 @@ int main(int argc, char** argv) { LOG(ERROR) << "Starting Optimization"; SGDSolver solver(solver_param); - solver.Solve(&caffe_net); + if (argc == 4) { + LOG(ERROR) << "Resuming from " << argv[3]; + solver.Solve(&caffe_net, argv[3]); + } else { + solver.Solve(&caffe_net); + } LOG(ERROR) << "Optimization Done."; return 0; diff --git a/include/caffe/net.hpp b/include/caffe/net.hpp index f0a5ebb9..9bbfd37e 100644 --- a/include/caffe/net.hpp +++ b/include/caffe/net.hpp @@ -65,13 +65,19 @@ class Net { void Update(); protected: + // Function to get misc parameters, e.g. the learning rate multiplier and + // weight decay. + void GetLearningRateAndWeightDecay(); + // Individual layers in the net vector > > layers_; vector layer_names_; + vector layer_need_backward_; // blobs stores the blobs that store intermediate results between the // layers. vector > > blobs_; vector blob_names_; + vector blob_need_backward_; // bottom_vecs stores the vectors containing the input for each layer vector*> > bottom_vecs_; vector > bottom_id_vecs_; diff --git a/src/caffe/net.cpp b/src/caffe/net.cpp index 165869d4..50f4f93a 100644 --- a/src/caffe/net.cpp +++ b/src/caffe/net.cpp @@ -34,6 +34,7 @@ Net::Net(const NetParameter& param, bottom[i]->height(), bottom[i]->width())); blobs_.push_back(blob_pointer); blob_names_.push_back(blob_name); + blob_need_backward_.push_back(false); net_input_blob_indices_.push_back(i); blob_name_to_idx[blob_name] = i; available_blobs.insert(blob_name); @@ -49,17 +50,21 @@ Net::Net(const NetParameter& param, layers_.push_back(shared_ptr >(GetLayer(layer_param))); layer_names_.push_back(layer_param.name()); LOG(INFO) << "Creating Layer " << layer_param.name(); + bool need_backward = false; // Figure out this layer's input and output for (int j = 0; j < layer_connection.bottom_size(); ++j) { const string& blob_name = layer_connection.bottom(j); + const int blob_id = blob_name_to_idx[blob_name]; if (available_blobs.find(blob_name) == available_blobs.end()) { LOG(FATAL) << "Unknown blob input " << blob_name << " to layer" << j; } LOG(INFO) << layer_param.name() << " <- " << blob_name; bottom_vecs_[i].push_back( - blobs_[blob_name_to_idx[blob_name]].get()); - bottom_id_vecs_[i].push_back(blob_name_to_idx[blob_name]); + blobs_[blob_id].get()); + bottom_id_vecs_[i].push_back(blob_id); + // If a blob needs backward, this layer should provide it. + need_backward |= blob_need_backward_[blob_id]; available_blobs.erase(blob_name); } for (int j = 0; j < layer_connection.top_size(); ++j) { @@ -83,12 +88,30 @@ Net::Net(const NetParameter& param, shared_ptr > blob_pointer(new Blob()); blobs_.push_back(blob_pointer); blob_names_.push_back(blob_name); + blob_need_backward_.push_back(false); blob_name_to_idx[blob_name] = blob_names_.size() - 1; available_blobs.insert(blob_name); top_vecs_[i].push_back(blobs_[blob_names_.size() - 1].get()); top_id_vecs_[i].push_back(blob_names_.size() - 1); } } + // After this layer is connected, set it up. + LOG(INFO) << "Setting up " << layer_names_[i]; + layers_[i]->SetUp(bottom_vecs_[i], &top_vecs_[i]); + // Check if this layer needs backward operation itself + for (int j = 0; j < layers_[i]->layer_param().blobs_lr_size(); ++j) { + need_backward |= (layers_[i]->layer_param().blobs_lr(j) > 0); + } + // Finally, set the backward flag + layer_need_backward_.push_back(need_backward); + if (need_backward) { + LOG(INFO) << layer_names_[i] << " needs backward computation."; + for (int j = 0; j < top_id_vecs_[i].size(); ++j) { + blob_need_backward_[top_id_vecs_[i][j]] = true; + } + } else { + LOG(INFO) << layer_names_[i] << " does not need backward computation."; + } } // In the end, all remaining blobs are considered output blobs. for (set::iterator it = available_blobs.begin(); @@ -97,11 +120,15 @@ Net::Net(const NetParameter& param, net_output_blob_indices_.push_back(blob_name_to_idx[*it]); net_output_blobs_.push_back(blobs_[blob_name_to_idx[*it]].get()); } + GetLearningRateAndWeightDecay(); + LOG(INFO) << "Network initialization done."; +} - LOG(INFO) << "Setting up the layers."; + +template +void Net::GetLearningRateAndWeightDecay() { + LOG(INFO) << "Collecting Learning Rate and Weight Decay."; for (int i = 0; i < layers_.size(); ++i) { - LOG(INFO) << "Setting up " << layer_names_[i]; - layers_[i]->SetUp(bottom_vecs_[i], &top_vecs_[i]); vector > >& layer_blobs = layers_[i]->blobs(); for (int j = 0; j < layer_blobs.size(); ++j) { params_.push_back(layer_blobs[j]); @@ -111,7 +138,7 @@ Net::Net(const NetParameter& param, CHECK_EQ(layers_[i]->layer_param().blobs_lr_size(), layer_blobs.size()); for (int j = 0; j < layer_blobs.size(); ++j) { float local_lr = layers_[i]->layer_param().blobs_lr(j); - CHECK_GT(local_lr, 0.); + CHECK_GE(local_lr, 0.); params_lr_.push_back(local_lr); } } else { @@ -125,7 +152,7 @@ Net::Net(const NetParameter& param, layer_blobs.size()); for (int j = 0; j < layer_blobs.size(); ++j) { float local_decay = layers_[i]->layer_param().weight_decay(j); - CHECK_GT(local_decay, 0.); + CHECK_GE(local_decay, 0.); params_weight_decay_.push_back(local_decay); } } else { @@ -139,7 +166,6 @@ Net::Net(const NetParameter& param, << top_vecs_[i][topid]->width(); } } - LOG(INFO) << "Network initialization done."; } template @@ -159,11 +185,12 @@ const vector*>& Net::Forward( template Dtype Net::Backward() { Dtype loss = 0; - // TODO(Yangqing): figure out those layers that do not need backward. for (int i = layers_.size() - 1; i >= 0; --i) { - Dtype layer_loss = layers_[i]->Backward( - top_vecs_[i], true, &bottom_vecs_[i]); - loss += layer_loss; + if (layer_need_backward_[i]) { + Dtype layer_loss = layers_[i]->Backward( + top_vecs_[i], true, &bottom_vecs_[i]); + loss += layer_loss; + } } return loss; }