Mirror of https://github.com/microsoft/caffe.git
Add need-backward computation and a train_net resume point. Not debugged.
Parent: 62089dd8da
Commit: 749ba87e48
@@ -3,7 +3,7 @@
 // This is a simple script that allows one to quickly train a network whose
 // parameters are specified by text format protocol buffers.
 // Usage:
-//    train_net net_proto_file solver_proto_file
+//    train_net net_proto_file solver_proto_file [resume_point_file]
 
 #include <cuda_runtime.h>
 
@@ -28,7 +28,12 @@ int main(int argc, char** argv) {
 
   LOG(ERROR) << "Starting Optimization";
   SGDSolver<float> solver(solver_param);
-  solver.Solve(&caffe_net);
+  if (argc == 4) {
+    LOG(ERROR) << "Resuming from " << argv[3];
+    solver.Solve(&caffe_net, argv[3]);
+  } else {
+    solver.Solve(&caffe_net);
+  }
   LOG(ERROR) << "Optimization Done.";
 
   return 0;
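With this change train_net accepts an optional third argument: when it is present, the solver is asked to resume from that file; otherwise optimization starts from scratch as before. A hypothetical invocation (the file names are placeholders; the resume file is whatever SGDSolver::Solve accepts as its second argument):

    train_net lenet.prototxt lenet_solver.prototxt lenet_iter_1000.snapshot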
@@ -65,13 +65,19 @@ class Net {
   void Update();
 
  protected:
+  // Function to get misc parameters, e.g. the learning rate multiplier and
+  // weight decay.
+  void GetLearningRateAndWeightDecay();
+
   // Individual layers in the net
   vector<shared_ptr<Layer<Dtype> > > layers_;
   vector<string> layer_names_;
+  vector<bool> layer_need_backward_;
   // blobs stores the blobs that store intermediate results between the
   // layers.
   vector<shared_ptr<Blob<Dtype> > > blobs_;
   vector<string> blob_names_;
+  vector<bool> blob_need_backward_;
   // bottom_vecs stores the vectors containing the input for each layer
   vector<vector<Blob<Dtype>*> > bottom_vecs_;
   vector<vector<int> > bottom_id_vecs_;
@@ -34,6 +34,7 @@ Net<Dtype>::Net(const NetParameter& param,
         bottom[i]->height(), bottom[i]->width()));
     blobs_.push_back(blob_pointer);
     blob_names_.push_back(blob_name);
+    blob_need_backward_.push_back(false);
     net_input_blob_indices_.push_back(i);
     blob_name_to_idx[blob_name] = i;
     available_blobs.insert(blob_name);
@@ -49,17 +50,21 @@ Net<Dtype>::Net(const NetParameter& param,
     layers_.push_back(shared_ptr<Layer<Dtype> >(GetLayer<Dtype>(layer_param)));
     layer_names_.push_back(layer_param.name());
     LOG(INFO) << "Creating Layer " << layer_param.name();
+    bool need_backward = false;
     // Figure out this layer's input and output
     for (int j = 0; j < layer_connection.bottom_size(); ++j) {
       const string& blob_name = layer_connection.bottom(j);
+      const int blob_id = blob_name_to_idx[blob_name];
       if (available_blobs.find(blob_name) == available_blobs.end()) {
         LOG(FATAL) << "Unknown blob input " << blob_name <<
             " to layer" << j;
       }
       LOG(INFO) << layer_param.name() << " <- " << blob_name;
       bottom_vecs_[i].push_back(
-          blobs_[blob_name_to_idx[blob_name]].get());
-      bottom_id_vecs_[i].push_back(blob_name_to_idx[blob_name]);
+          blobs_[blob_id].get());
+      bottom_id_vecs_[i].push_back(blob_id);
+      // If a blob needs backward, this layer should provide it.
+      need_backward |= blob_need_backward_[blob_id];
       available_blobs.erase(blob_name);
     }
     for (int j = 0; j < layer_connection.top_size(); ++j) {
@@ -83,12 +88,30 @@ Net<Dtype>::Net(const NetParameter& param,
         shared_ptr<Blob<Dtype> > blob_pointer(new Blob<Dtype>());
         blobs_.push_back(blob_pointer);
         blob_names_.push_back(blob_name);
+        blob_need_backward_.push_back(false);
         blob_name_to_idx[blob_name] = blob_names_.size() - 1;
         available_blobs.insert(blob_name);
         top_vecs_[i].push_back(blobs_[blob_names_.size() - 1].get());
         top_id_vecs_[i].push_back(blob_names_.size() - 1);
       }
     }
+    // After this layer is connected, set it up.
+    LOG(INFO) << "Setting up " << layer_names_[i];
+    layers_[i]->SetUp(bottom_vecs_[i], &top_vecs_[i]);
+    // Check if this layer needs backward operation itself
+    for (int j = 0; j < layers_[i]->layer_param().blobs_lr_size(); ++j) {
+      need_backward |= (layers_[i]->layer_param().blobs_lr(j) > 0);
+    }
+    // Finally, set the backward flag
+    layer_need_backward_.push_back(need_backward);
+    if (need_backward) {
+      LOG(INFO) << layer_names_[i] << " needs backward computation.";
+      for (int j = 0; j < top_id_vecs_[i].size(); ++j) {
+        blob_need_backward_[top_id_vecs_[i][j]] = true;
+      }
+    } else {
+      LOG(INFO) << layer_names_[i] << " does not need backward computation.";
+    }
   }
   // In the end, all remaining blobs are considered output blobs.
   for (set<string>::iterator it = available_blobs.begin();
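The additions above wire up the need-backward bookkeeping: a layer needs backward if any of its bottom blobs needs backward or if any of its own parameters has a positive learning rate, and its top blobs then inherit the flag. Below is a minimal standalone sketch of that rule, not Caffe code; the ToyLayer struct, the layer names, and the three-layer graph are made up for illustration.

#include <cstdio>
#include <map>
#include <string>
#include <vector>

// Toy stand-in for a layer's connectivity and per-parameter learning rates.
struct ToyLayer {
  std::string name;
  std::vector<std::string> bottoms;
  std::vector<std::string> tops;
  std::vector<float> blobs_lr;
};

int main() {
  // Hypothetical chain: data -> frozen conv1 (lr 0) -> trainable ip1.
  std::vector<ToyLayer> layers;
  layers.push_back(ToyLayer{"data", {}, {"data"}, {}});
  layers.push_back(ToyLayer{"conv1", {"data"}, {"conv1"}, {0.f, 0.f}});
  layers.push_back(ToyLayer{"ip1", {"conv1"}, {"ip1"}, {1.f, 2.f}});

  std::map<std::string, bool> blob_need_backward;
  for (size_t i = 0; i < layers.size(); ++i) {
    const ToyLayer& layer = layers[i];
    bool need_backward = false;
    // If any bottom blob needs backward, this layer must provide it.
    for (size_t j = 0; j < layer.bottoms.size(); ++j) {
      need_backward |= blob_need_backward[layer.bottoms[j]];
    }
    // A positive learning rate on any of its own parameters also forces it.
    for (size_t j = 0; j < layer.blobs_lr.size(); ++j) {
      need_backward |= (layer.blobs_lr[j] > 0);
    }
    // Top blobs produced by a backward-needing layer inherit the flag.
    if (need_backward) {
      for (size_t j = 0; j < layer.tops.size(); ++j) {
        blob_need_backward[layer.tops[j]] = true;
      }
    }
    std::printf("%s %s backward computation.\n", layer.name.c_str(),
                need_backward ? "needs" : "does not need");
  }
  return 0;
}

Together with the CHECK_GT to CHECK_GE relaxations further down, this lets a layer's blobs_lr (and weight_decay) entries be zero: such parameters receive no updates, and the layer is skipped in Backward() unless one of its bottom blobs still needs gradients.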
@@ -97,11 +120,15 @@ Net<Dtype>::Net(const NetParameter& param,
     net_output_blob_indices_.push_back(blob_name_to_idx[*it]);
     net_output_blobs_.push_back(blobs_[blob_name_to_idx[*it]].get());
   }
+  GetLearningRateAndWeightDecay();
+  LOG(INFO) << "Network initialization done.";
+}
+
-  LOG(INFO) << "Setting up the layers.";
 
+template <typename Dtype>
+void Net<Dtype>::GetLearningRateAndWeightDecay() {
+  LOG(INFO) << "Collecting Learning Rate and Weight Decay.";
   for (int i = 0; i < layers_.size(); ++i) {
-    LOG(INFO) << "Setting up " << layer_names_[i];
-    layers_[i]->SetUp(bottom_vecs_[i], &top_vecs_[i]);
     vector<shared_ptr<Blob<Dtype> > >& layer_blobs = layers_[i]->blobs();
     for (int j = 0; j < layer_blobs.size(); ++j) {
       params_.push_back(layer_blobs[j]);
@@ -111,7 +138,7 @@ Net<Dtype>::Net(const NetParameter& param,
       CHECK_EQ(layers_[i]->layer_param().blobs_lr_size(), layer_blobs.size());
       for (int j = 0; j < layer_blobs.size(); ++j) {
         float local_lr = layers_[i]->layer_param().blobs_lr(j);
-        CHECK_GT(local_lr, 0.);
+        CHECK_GE(local_lr, 0.);
         params_lr_.push_back(local_lr);
       }
     } else {
@@ -125,7 +152,7 @@ Net<Dtype>::Net(const NetParameter& param,
           layer_blobs.size());
       for (int j = 0; j < layer_blobs.size(); ++j) {
         float local_decay = layers_[i]->layer_param().weight_decay(j);
-        CHECK_GT(local_decay, 0.);
+        CHECK_GE(local_decay, 0.);
         params_weight_decay_.push_back(local_decay);
       }
     } else {
@@ -139,7 +166,6 @@ Net<Dtype>::Net(const NetParameter& param,
           << top_vecs_[i][topid]->width();
     }
   }
-  LOG(INFO) << "Network initialization done.";
 }
 
 template <typename Dtype>
@@ -159,11 +185,12 @@ const vector<Blob<Dtype>*>& Net<Dtype>::Forward(
 template <typename Dtype>
 Dtype Net<Dtype>::Backward() {
   Dtype loss = 0;
-  // TODO(Yangqing): figure out those layers that do not need backward.
   for (int i = layers_.size() - 1; i >= 0; --i) {
-    Dtype layer_loss = layers_[i]->Backward(
-        top_vecs_[i], true, &bottom_vecs_[i]);
-    loss += layer_loss;
+    if (layer_need_backward_[i]) {
+      Dtype layer_loss = layers_[i]->Backward(
+          top_vecs_[i], true, &bottom_vecs_[i]);
+      loss += layer_loss;
+    }
   }
   return loss;
 }