solver restructuring: now all prototxt are specified in the solver protocol buffer

Yangqing Jia 2013-10-31 16:52:22 -07:00
Parent 25a865cd8b
Commit 82b912be84
11 changed files with 146 additions and 147 deletions

View file

@@ -47,7 +47,7 @@ LIBRARIES := cuda cudart cublas curand protobuf opencv_core opencv_highgui \
 	glog mkl_rt mkl_intel_thread leveldb snappy pthread
 WARNINGS := -Wall
-COMMON_FLAGS := $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir))
+COMMON_FLAGS := -DNDEBUG $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir))
 CXXFLAGS += -pthread -fPIC -O2 $(COMMON_FLAGS)
 NVCCFLAGS := -Xcompiler -fPIC -O2 $(COMMON_FLAGS)
 LDFLAGS += $(foreach librarydir,$(LIBRARY_DIRS),-L$(librarydir)) \

View file

@@ -0,0 +1,12 @@
+train_net: "data/lenet.prototxt"
+test_net: "data/lenet_test.prototxt"
+base_lr: 0.01
+lr_policy: "inv"
+gamma: 0.0001
+power: 0.75
+display: 100
+max_iter: 5000
+momentum: 0.9
+weight_decay: 0.0005
+test_iter: 100
+test_interval: 500
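For context, a minimal sketch of how a solver file like this is consumed after the restructuring, mirroring the updated examples/train_net.cpp below (the path data/lenet_solver.prototxt is an assumed name for the new file, which this diff view does not show):

#include "caffe/proto/caffe.pb.h"
#include "caffe/solver.hpp"
#include "caffe/util/io.hpp"

using namespace caffe;

int main(int argc, char** argv) {
  SolverParameter solver_param;
  // The solver file now carries the train_net/test_net paths, so this is
  // the only input the caller has to supply.
  ReadProtoFromTextFile("data/lenet_solver.prototxt", &solver_param);
  SGDSolver<float> solver(solver_param);  // scaffolds the train and test nets
  solver.Solve();  // trains, and runs Test() every test_interval iterations
  return 0;
}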

View file

@@ -1,96 +0,0 @@
-// Copyright 2013 Yangqing Jia
-// This example shows how to run a modified version of LeNet using Caffe.
-
-#include <cuda_runtime.h>
-#include <fcntl.h>
-#include <google/protobuf/text_format.h>
-
-#include <cstring>
-#include <iostream>
-
-#include "caffe/blob.hpp"
-#include "caffe/common.hpp"
-#include "caffe/net.hpp"
-#include "caffe/filler.hpp"
-#include "caffe/proto/caffe.pb.h"
-#include "caffe/util/io.hpp"
-#include "caffe/solver.hpp"
-
-using namespace caffe;
-
-int main(int argc, char** argv) {
-  if (argc < 3) {
-    std::cout << "Usage:" << std::endl;
-    std::cout << "demo_mnist.bin train_file test_file [CPU/GPU]" << std::endl;
-    return 0;
-  }
-  google::InitGoogleLogging(argv[0]);
-
-  Caffe::DeviceQuery();
-
-  if (argc == 4 && strcmp(argv[3], "GPU") == 0) {
-    LOG(ERROR) << "Using GPU";
-    Caffe::set_mode(Caffe::GPU);
-  } else {
-    LOG(ERROR) << "Using CPU";
-    Caffe::set_mode(Caffe::CPU);
-  }
-
-  // Start training
-  Caffe::set_phase(Caffe::TRAIN);
-
-  NetParameter net_param;
-  ReadProtoFromTextFile(argv[1], &net_param);
-  vector<Blob<float>*> bottom_vec;
-  Net<float> caffe_net(net_param, bottom_vec);
-
-  // Run the network without training.
-  LOG(ERROR) << "Performing Forward";
-  caffe_net.Forward(bottom_vec);
-  LOG(ERROR) << "Performing Backward";
-  LOG(ERROR) << "Initial loss: " << caffe_net.Backward();
-
-  SolverParameter solver_param;
-  // Solver Parameters are hard-coded in this case, but you can write a
-  // SolverParameter protocol buffer to specify all these values.
-  solver_param.set_base_lr(0.01);
-  solver_param.set_display(100);
-  solver_param.set_max_iter(5000);
-  solver_param.set_lr_policy("inv");
-  solver_param.set_gamma(0.0001);
-  solver_param.set_power(0.75);
-  solver_param.set_momentum(0.9);
-  solver_param.set_weight_decay(0.0005);
-
-  LOG(ERROR) << "Starting Optimization";
-  SGDSolver<float> solver(solver_param);
-  solver.Solve(&caffe_net);
-  LOG(ERROR) << "Optimization Done.";
-
-  // Write the trained network to a NetParameter protobuf. If you are training
-  // the model and saving it for later, this is what you want to serialize and
-  // store.
-  NetParameter trained_net_param;
-  caffe_net.ToProto(&trained_net_param);
-
-  // Now, let's starting doing testing.
-  Caffe::set_phase(Caffe::TEST);
-
-  // Using the testing data to test the accuracy.
-  NetParameter test_net_param;
-  ReadProtoFromTextFile(argv[2], &test_net_param);
-  Net<float> caffe_test_net(test_net_param, bottom_vec);
-  caffe_test_net.CopyTrainedLayersFrom(trained_net_param);
-
-  double test_accuracy = 0;
-  int batch_size = test_net_param.layers(0).layer().batchsize();
-  for (int i = 0; i < 10000 / batch_size; ++i) {
-    const vector<Blob<float>*>& result =
-        caffe_test_net.Forward(bottom_vec);
-    test_accuracy += result[0]->cpu_data()[0];
-  }
-  test_accuracy /= 10000 / batch_size;
-  LOG(ERROR) << "Test accuracy:" << test_accuracy;
-
-  return 0;
-}

View file

@@ -14,33 +14,27 @@
 using namespace caffe;
 
 int main(int argc, char** argv) {
-  if (argc < 3) {
-    LOG(ERROR) << "Usage: train_net net_proto_file solver_proto_file "
-        << "[resume_point_file]";
+  ::google::InitGoogleLogging(argv[0]);
+  if (argc < 2) {
+    LOG(ERROR) << "Usage: train_net solver_proto_file [resume_point_file]";
     return 0;
   }
 
-  cudaSetDevice(0);
+  Caffe::SetDevice(0);
   Caffe::set_mode(Caffe::GPU);
-  Caffe::set_phase(Caffe::TRAIN);
-
-  NetParameter net_param;
-  ReadProtoFromTextFile(argv[1], &net_param);
-  vector<Blob<float>*> bottom_vec;
-  Net<float> caffe_net(net_param, bottom_vec);
 
   SolverParameter solver_param;
-  ReadProtoFromTextFile(argv[2], &solver_param);
+  ReadProtoFromTextFile(argv[1], &solver_param);
 
-  LOG(ERROR) << "Starting Optimization";
+  LOG(INFO) << "Starting Optimization";
   SGDSolver<float> solver(solver_param);
-  if (argc == 4) {
-    LOG(ERROR) << "Resuming from " << argv[3];
-    solver.Solve(&caffe_net, argv[3]);
+  if (argc == 3) {
+    LOG(INFO) << "Resuming from " << argv[2];
+    solver.Solve(argv[2]);
   } else {
-    solver.Solve(&caffe_net);
+    solver.Solve();
   }
-  LOG(ERROR) << "Optimization Done.";
+  LOG(INFO) << "Optimization Done.";
   return 0;
 }
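With this change an invocation needs only the solver definition, plus an optional snapshot to resume from; for example (binary and snapshot file names assumed, not taken from this commit):

  train_net.bin data/lenet_solver.prototxt
  train_net.bin data/lenet_solver.prototxt lenet_snapshot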

View file

@@ -94,6 +94,9 @@ class Caffe {
   inline static void set_phase(Phase phase) { Get().phase_ = phase; }
   // Sets the random seed of both MKL and curand
   static void set_random_seed(const unsigned int seed);
+  // Sets the device. Since we have cublas and curand stuff, set device also
+  // requires us to reset those values.
+  static void SetDevice(const int device_id);
   // Prints the current GPU status.
   static void DeviceQuery();
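Judging from the updated examples/train_net.cpp, the intended call order is to select the device before switching to GPU mode, so the recreated handles belong to the chosen device; a small usage sketch:

  Caffe::SetDevice(0);          // recreates cublas/curand handles if the device changes
  Caffe::set_mode(Caffe::GPU);  // subsequent math then runs on device 0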

View file

@@ -10,11 +10,11 @@ namespace caffe {
 template <typename Dtype>
 class Solver {
  public:
-  explicit Solver(const SolverParameter& param)
-      : param_(param) {}
+  explicit Solver(const SolverParameter& param);
   // The main entry of the solver function. In default, iter will be zero. Pass
   // in a non-zero iter number to resume training for a pre-trained net.
-  void Solve(Net<Dtype>* net, char* state_file = NULL);
+  void Solve(const char* resume_file = NULL);
+  inline void Solve(const string resume_file) { Solve(resume_file.c_str()); }
   virtual ~Solver() {}
 
  protected:
@@ -28,15 +28,18 @@ class Solver {
   // function that produces a SolverState protocol buffer that needs to be
   // written to disk together with the learned net.
   void Snapshot();
+  // The test routine
+  void Test();
   virtual void SnapshotSolverState(SolverState* state) = 0;
   // The Restore function implements how one should restore the solver to a
   // previously snapshotted state. You should implement the RestoreSolverState()
   // function that restores the state from a SolverState protocol buffer.
-  void Restore(char* state_file);
+  void Restore(const char* resume_file);
   virtual void RestoreSolverState(const SolverState& state) = 0;
 
   SolverParameter param_;
   int iter_;
   Net<Dtype>* net_;
+  Net<Dtype>* test_net_;
 
   DISABLE_COPY_AND_ASSIGN(Solver);
 };

View file

@@ -74,6 +74,24 @@ void Caffe::set_random_seed(const unsigned int seed) {
   VSL_CHECK(vslNewStream(&(Get().vsl_stream_), VSL_BRNG_MT19937, seed));
 }
 
+void Caffe::SetDevice(const int device_id) {
+  int current_device;
+  CUDA_CHECK(cudaGetDevice(&current_device));
+  if (current_device == device_id) {
+    return;
+  }
+  if (Get().cublas_handle_) CUBLAS_CHECK(cublasDestroy(Get().cublas_handle_));
+  if (Get().curand_generator_) {
+    CURAND_CHECK(curandDestroyGenerator(Get().curand_generator_));
+  }
+  CUDA_CHECK(cudaSetDevice(device_id));
+  CUBLAS_CHECK(cublasCreate(&Get().cublas_handle_));
+  CURAND_CHECK(curandCreateGenerator(&Get().curand_generator_,
+      CURAND_RNG_PSEUDO_DEFAULT));
+  CURAND_CHECK(curandSetPseudoRandomGeneratorSeed(Get().curand_generator_,
+      time(NULL)));
+}
+
 void Caffe::DeviceQuery() {
   cudaDeviceProp prop;
   int device;

View file

@@ -101,7 +101,7 @@ void* DataLayerPrefetch(void* layer_pointer) {
       layer->iter_->Next();
       if (!layer->iter_->Valid()) {
         // We have reached the end. Restart from the first.
-        LOG(INFO) << "Restarting data prefetching from start.";
+        DLOG(INFO) << "Restarting data prefetching from start.";
         layer->iter_->SeekToFirst();
       }
     }
@@ -180,10 +180,10 @@ void DataLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
   prefetch_data_->mutable_cpu_data();
   prefetch_label_->mutable_cpu_data();
   data_mean_.cpu_data();
-  // LOG(INFO) << "Initializing prefetch";
+  DLOG(INFO) << "Initializing prefetch";
   CHECK(!pthread_create(&thread_, NULL, DataLayerPrefetch<Dtype>,
       reinterpret_cast<void*>(this))) << "Pthread execution failed.";
-  // LOG(INFO) << "Prefetch initialized.";
+  DLOG(INFO) << "Prefetch initialized.";
 }
 
 template <typename Dtype>

View file

@@ -121,7 +121,7 @@ Net<Dtype>::Net(const NetParameter& param,
   // In the end, all remaining blobs are considered output blobs.
   for (set<string>::iterator it = available_blobs.begin();
       it != available_blobs.end(); ++it) {
-    LOG(ERROR) << "This network produces output " << *it;
+    LOG(INFO) << "This network produces output " << *it;
     net_output_blob_indices_.push_back(blob_name_to_idx[*it]);
     net_output_blobs_.push_back(blobs_[blob_name_to_idx[*it]].get());
   }
@@ -207,10 +207,10 @@ void Net<Dtype>::CopyTrainedLayersFrom(const NetParameter& param) {
       ++target_layer_id;
     }
     if (target_layer_id == layer_names_.size()) {
-      LOG(INFO) << "Ignoring source layer " << source_layer_name;
+      DLOG(INFO) << "Ignoring source layer " << source_layer_name;
       continue;
     }
-    LOG(INFO) << "Loading source layer " << source_layer_name;
+    DLOG(INFO) << "Loading source layer " << source_layer_name;
     vector<shared_ptr<Blob<Dtype> > >& target_blobs =
         layers_[target_layer_id]->blobs();
     CHECK_EQ(target_blobs.size(), source_layer.blobs_size())
@@ -233,7 +233,7 @@ void Net<Dtype>::ToProto(NetParameter* param, bool write_diff) {
   for (int i = 0; i < net_input_blob_indices_.size(); ++i) {
     param->add_input(blob_names_[net_input_blob_indices_[i]]);
   }
-  LOG(INFO) << "Serializing " << layers_.size() << " layers";
+  DLOG(INFO) << "Serializing " << layers_.size() << " layers";
   for (int i = 0; i < layers_.size(); ++i) {
     LayerConnection* layer_connection = param->add_layers();
     for (int j = 0; j < bottom_id_vecs_[i].size(); ++j) {

View file

@@ -94,25 +94,28 @@ message NetParameter {
 }
 
 message SolverParameter {
-  optional float base_lr = 1; // The base learning rate
+  optional string train_net = 1; // The proto file for the training net.
+  optional string test_net = 2; // The proto file for the testing net.
+  // The number of iterations for each testing phase.
+  optional int32 test_iter = 3 [ default = 0 ];
+  // The number of iterations between two testing phases.
+  optional int32 test_interval = 4 [ default = 0 ];
+  optional float base_lr = 5; // The base learning rate
   // the number of iterations between displaying info. If display = 0, no info
   // will be displayed.
-  optional int32 display = 2;
-  optional int32 max_iter = 3; // the maximum number of iterations
-  optional int32 snapshot = 4 [default = 0]; // The snapshot interval
-  optional string lr_policy = 5; // The learning rate decay policy.
-  optional float min_lr = 6 [default = 0]; // The mininum learning rate
-  optional float max_lr = 7 [default = 1e10]; // The maximum learning rate
-  optional float gamma = 8; // The parameter to compute the learning rate.
-  optional float power = 9; // The parameter to compute the learning rate.
-  optional float momentum = 10; // The momentum value.
-  optional float weight_decay = 11; // The weight decay.
-  optional int32 stepsize = 12; // the stepsize for learning rate policy "step"
-  optional string snapshot_prefix = 13; // The prefix for the snapshot.
+  optional int32 display = 6;
+  optional int32 max_iter = 7; // the maximum number of iterations
+  optional string lr_policy = 8; // The learning rate decay policy.
+  optional float gamma = 9; // The parameter to compute the learning rate.
+  optional float power = 10; // The parameter to compute the learning rate.
+  optional float momentum = 11; // The momentum value.
+  optional float weight_decay = 12; // The weight decay.
+  optional int32 stepsize = 13; // the stepsize for learning rate policy "step"
+  optional int32 snapshot = 14 [default = 0]; // The snapshot interval
+  optional string snapshot_prefix = 15; // The prefix for the snapshot.
   // whether to snapshot diff in the results or not. Snapshotting diff will help
   // debugging but the final protocol buffer size will be much larger.
-  optional bool snapshot_diff = 14 [ default = false];
+  optional bool snapshot_diff = 16 [ default = false];
 }
// A message that stores the solver snapshots // A message that stores the solver snapshots
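One upgrade note (an observation about protobuf semantics, not something this commit addresses): the surviving fields were renumbered (snapshot 4 to 14, lr_policy 5 to 8, and so on), so a SolverParameter serialized in binary wire format under the old schema will no longer decode correctly against the new one. Text-format files such as the lenet solver added above match fields by name and are unaffected.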

View file

@@ -18,8 +18,31 @@ using std::min;
 namespace caffe {
 
 template <typename Dtype>
-void Solver<Dtype>::Solve(Net<Dtype>* net, char* resume_file) {
-  net_ = net;
+Solver<Dtype>::Solver(const SolverParameter& param)
+    : param_(param), net_(NULL), test_net_(NULL) {
+  // Scaffolding code
+  NetParameter train_net_param;
+  ReadProtoFromTextFile(param_.train_net(), &train_net_param);
+  // For the training network, there should be no input - so we simply create
+  // a dummy bottom_vec instance to initialize the networks.
+  vector<Blob<Dtype>*> bottom_vec;
+  LOG(INFO) << "Creating training net.";
+  net_ = new Net<Dtype>(train_net_param, bottom_vec);
+  if (param_.has_test_net()) {
+    LOG(INFO) << "Creating testing net.";
+    NetParameter test_net_param;
+    ReadProtoFromTextFile(param_.test_net(), &test_net_param);
+    test_net_ = new Net<Dtype>(test_net_param, bottom_vec);
+    CHECK_GT(param_.test_iter(), 0);
+    CHECK_GT(param_.test_interval(), 0);
+  }
+  LOG(INFO) << "Solver scaffolding done.";
+}
+
+template <typename Dtype>
+void Solver<Dtype>::Solve(const char* resume_file) {
+  Caffe::set_phase(Caffe::TRAIN);
   LOG(INFO) << "Solving " << net_->name();
   PreSolve();
 
@@ -42,13 +65,54 @@ void Solver<Dtype>::Solve(Net<Dtype>* net, char* resume_file) {
       Snapshot();
     }
     if (param_.display() && iter_ % param_.display() == 0) {
-      LOG(ERROR) << "Iteration " << iter_ << ", loss = " << loss;
+      LOG(INFO) << "Iteration " << iter_ << ", loss = " << loss;
+    }
+    if (param_.test_interval() && iter_ % param_.test_interval() == 0) {
+      // We need to set phase to test before running.
+      Caffe::set_phase(Caffe::TEST);
+      Test();
+      Caffe::set_phase(Caffe::TRAIN);
     }
   }
   LOG(INFO) << "Optimization Done.";
 }
 
+template <typename Dtype>
+void Solver<Dtype>::Test() {
+  LOG(INFO) << "Testing net";
+  NetParameter net_param;
+  net_->ToProto(&net_param);
+  CHECK_NOTNULL(test_net_)->CopyTrainedLayersFrom(net_param);
+  vector<Dtype> test_score;
+  vector<Blob<Dtype>*> bottom_vec;
+  for (int i = 0; i < param_.test_iter(); ++i) {
+    const vector<Blob<Dtype>*>& result =
+        test_net_->Forward(bottom_vec);
+    if (i == 0) {
+      for (int j = 0; j < result.size(); ++j) {
+        const Dtype* result_vec = result[j]->cpu_data();
+        for (int k = 0; k < result[j]->count(); ++k) {
+          test_score.push_back(result_vec[k]);
+        }
+      }
+    } else {
+      int idx = 0;
+      for (int j = 0; j < result.size(); ++j) {
+        const Dtype* result_vec = result[j]->cpu_data();
+        for (int k = 0; k < result[j]->count(); ++k) {
+          test_score[idx++] += result_vec[k];
+        }
+      }
+    }
+  }
+  for (int i = 0; i < test_score.size(); ++i) {
+    LOG(INFO) << "Test score #" << i << ": "
+        << test_score[i] / param_.test_iter();
+  }
+}
+
 template <typename Dtype>
 void Solver<Dtype>::Snapshot() {
   NetParameter net_param;
@@ -70,7 +134,7 @@ void Solver<Dtype>::Snapshot() {
 }
 
 template <typename Dtype>
-void Solver<Dtype>::Restore(char* state_file) {
+void Solver<Dtype>::Restore(const char* state_file) {
   SolverState state;
   NetParameter net_param;
   ReadProtoFromBinaryFile(state_file, &state);
@@ -108,8 +172,6 @@ Dtype SGDSolver<Dtype>::GetLearningRate() {
   } else {
     LOG(FATAL) << "Unknown learning rate policy: " << lr_policy;
   }
-  rate = min(max(rate, Dtype(this->param_.min_lr())),
-      Dtype(this->param_.max_lr()));
   return rate;
 }
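With the min_lr/max_lr clamp gone, GetLearningRate() now returns the policy value unmodified. For the "inv" policy selected by the new lenet solver file, the computation (not visible in this hunk, so treat this sketch as an assumption about the surrounding code) would be:

  // "inv" policy: rate = base_lr * (1 + gamma * iter) ^ (-power)
  // With the example settings (base_lr: 0.01, gamma: 0.0001, power: 0.75):
  //   iter 0    -> 0.01
  //   iter 5000 -> 0.01 * pow(1.5, -0.75), roughly 0.0074
  Dtype rate = this->param_.base_lr() *
      pow(Dtype(1) + this->param_.gamma() * this->iter_,
          -this->param_.power());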
@@ -136,7 +198,7 @@ void SGDSolver<Dtype>::ComputeUpdateValue() {
   // get the learning rate
   Dtype rate = GetLearningRate();
   if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
-    LOG(ERROR) << "Iteration " << this->iter_ << ", lr = " << rate;
+    LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate;
   }
   Dtype momentum = this->param_.momentum();
   Dtype weight_decay = this->param_.weight_decay();