solver restructuring: all prototxt files are now specified in the solver protocol buffer

Yangqing Jia 2013-10-31 16:52:22 -07:00
Parent 25a865cd8b
Commit 82b912be84
11 changed files with 146 additions and 147 deletions

View file

@@ -47,7 +47,7 @@ LIBRARIES := cuda cudart cublas curand protobuf opencv_core opencv_highgui \
glog mkl_rt mkl_intel_thread leveldb snappy pthread
WARNINGS := -Wall
COMMON_FLAGS := $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir))
COMMON_FLAGS := -DNDEBUG $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir))
CXXFLAGS += -pthread -fPIC -O2 $(COMMON_FLAGS)
NVCCFLAGS := -Xcompiler -fPIC -O2 $(COMMON_FLAGS)
LDFLAGS += $(foreach librarydir,$(LIBRARY_DIRS),-L$(librarydir)) \

View file

@@ -0,0 +1,12 @@
train_net: "data/lenet.prototxt"
test_net: "data/lenet_test.prototxt"
base_lr: 0.01
lr_policy: "inv"
gamma: 0.0001
power: 0.75
display: 100
max_iter: 5000
momentum: 0.9
weight_decay: 0.0005
test_iter: 100
test_interval: 500
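With this restructuring the solver prototxt above is self-contained: it names both net definitions (train_net/test_net) and every optimization hyperparameter. A minimal sketch of consuming it, mirroring the new train_net.cpp entry point added in this commit (error handling omitted):

#include "caffe/proto/caffe.pb.h"
#include "caffe/solver.hpp"
#include "caffe/util/io.hpp"
using namespace caffe;

int main(int argc, char** argv) {
  ::google::InitGoogleLogging(argv[0]);
  SolverParameter solver_param;
  ReadProtoFromTextFile("data/lenet_solver.prototxt", &solver_param);
  // The solver builds the train and test nets itself from the
  // train_net/test_net paths above; the caller no longer passes in a
  // NetParameter.
  SGDSolver<float> solver(solver_param);
  solver.Solve();
  return 0;
}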

View file

@@ -1,96 +0,0 @@
// Copyright 2013 Yangqing Jia
// This example shows how to run a modified version of LeNet using Caffe.
#include <cuda_runtime.h>
#include <fcntl.h>
#include <google/protobuf/text_format.h>
#include <cstring>
#include <iostream>
#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/net.hpp"
#include "caffe/filler.hpp"
#include "caffe/proto/caffe.pb.h"
#include "caffe/util/io.hpp"
#include "caffe/solver.hpp"
using namespace caffe;
int main(int argc, char** argv) {
if (argc < 3) {
std::cout << "Usage:" << std::endl;
std::cout << "demo_mnist.bin train_file test_file [CPU/GPU]" << std::endl;
return 0;
}
google::InitGoogleLogging(argv[0]);
Caffe::DeviceQuery();
if (argc == 4 && strcmp(argv[3], "GPU") == 0) {
LOG(ERROR) << "Using GPU";
Caffe::set_mode(Caffe::GPU);
} else {
LOG(ERROR) << "Using CPU";
Caffe::set_mode(Caffe::CPU);
}
// Start training
Caffe::set_phase(Caffe::TRAIN);
NetParameter net_param;
ReadProtoFromTextFile(argv[1], &net_param);
vector<Blob<float>*> bottom_vec;
Net<float> caffe_net(net_param, bottom_vec);
// Run the network without training.
LOG(ERROR) << "Performing Forward";
caffe_net.Forward(bottom_vec);
LOG(ERROR) << "Performing Backward";
LOG(ERROR) << "Initial loss: " << caffe_net.Backward();
SolverParameter solver_param;
// Solver Parameters are hard-coded in this case, but you can write a
// SolverParameter protocol buffer to specify all these values.
solver_param.set_base_lr(0.01);
solver_param.set_display(100);
solver_param.set_max_iter(5000);
solver_param.set_lr_policy("inv");
solver_param.set_gamma(0.0001);
solver_param.set_power(0.75);
solver_param.set_momentum(0.9);
solver_param.set_weight_decay(0.0005);
LOG(ERROR) << "Starting Optimization";
SGDSolver<float> solver(solver_param);
solver.Solve(&caffe_net);
LOG(ERROR) << "Optimization Done.";
// Write the trained network to a NetParameter protobuf. If you are training
// the model and saving it for later, this is what you want to serialize and
// store.
NetParameter trained_net_param;
caffe_net.ToProto(&trained_net_param);
// Now, let's start testing.
Caffe::set_phase(Caffe::TEST);
// Using the testing data to test the accuracy.
NetParameter test_net_param;
ReadProtoFromTextFile(argv[2], &test_net_param);
Net<float> caffe_test_net(test_net_param, bottom_vec);
caffe_test_net.CopyTrainedLayersFrom(trained_net_param);
double test_accuracy = 0;
int batch_size = test_net_param.layers(0).layer().batchsize();
for (int i = 0; i < 10000 / batch_size; ++i) {
const vector<Blob<float>*>& result =
caffe_test_net.Forward(bottom_vec);
test_accuracy += result[0]->cpu_data()[0];
}
test_accuracy /= 10000 / batch_size;
LOG(ERROR) << "Test accuracy:" << test_accuracy;
return 0;
}
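(The hand-rolled training and accuracy loop above is superseded by the restructured solver: it now builds the train and test nets itself from the train_net/test_net paths in SolverParameter and runs the generic Solver::Test() routine every test_interval iterations.)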

View file

@@ -14,33 +14,27 @@
using namespace caffe;
int main(int argc, char** argv) {
if (argc < 3) {
LOG(ERROR) << "Usage: train_net net_proto_file solver_proto_file "
<< "[resume_point_file]";
::google::InitGoogleLogging(argv[0]);
if (argc < 2) {
LOG(ERROR) << "Usage: train_net solver_proto_file [resume_point_file]";
return 0;
}
cudaSetDevice(0);
Caffe::SetDevice(0);
Caffe::set_mode(Caffe::GPU);
Caffe::set_phase(Caffe::TRAIN);
NetParameter net_param;
ReadProtoFromTextFile(argv[1], &net_param);
vector<Blob<float>*> bottom_vec;
Net<float> caffe_net(net_param, bottom_vec);
SolverParameter solver_param;
ReadProtoFromTextFile(argv[2], &solver_param);
ReadProtoFromTextFile(argv[1], &solver_param);
LOG(ERROR) << "Starting Optimization";
LOG(INFO) << "Starting Optimization";
SGDSolver<float> solver(solver_param);
if (argc == 4) {
LOG(ERROR) << "Resuming from " << argv[3];
solver.Solve(&caffe_net, argv[3]);
if (argc == 3) {
LOG(INFO) << "Resuming from " << argv[2];
solver.Solve(argv[2]);
} else {
solver.Solve(&caffe_net);
solver.Solve();
}
LOG(ERROR) << "Optimization Done.";
LOG(INFO) << "Optimization Done.";
return 0;
}
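After this change the tool takes the solver prototxt as its only required argument, e.g. train_net data/lenet_solver.prototxt, with an optional second argument naming a previously snapshotted solver state to resume from (the exact snapshot filename depends on snapshot_prefix).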

View file

@@ -94,6 +94,9 @@ class Caffe {
inline static void set_phase(Phase phase) { Get().phase_ = phase; }
// Sets the random seed of both MKL and curand
static void set_random_seed(const unsigned int seed);
// Sets the device. Since we hold cublas and curand handles, setting the
// device also requires resetting those handles.
static void SetDevice(const int device_id);
// Prints the current GPU status.
static void DeviceQuery();

View file

@@ -10,11 +10,11 @@ namespace caffe {
template <typename Dtype>
class Solver {
public:
explicit Solver(const SolverParameter& param)
: param_(param) {}
explicit Solver(const SolverParameter& param);
// The main entry of the solver function. By default training starts from
// iteration zero; pass in a resume file to continue from a snapshotted state.
void Solve(Net<Dtype>* net, char* state_file = NULL);
void Solve(const char* resume_file = NULL);
inline void Solve(const string resume_file) { Solve(resume_file.c_str()); }
virtual ~Solver() {}
protected:
@@ -28,15 +28,18 @@ class Solver {
// function that produces a SolverState protocol buffer that needs to be
// written to disk together with the learned net.
void Snapshot();
// The test routine
void Test();
virtual void SnapshotSolverState(SolverState* state) = 0;
// The Restore function implements how one should restore the solver to a
// previously snapshotted state. You should implement the RestoreSolverState()
// function that restores the state from a SolverState protocol buffer.
void Restore(char* state_file);
void Restore(const char* resume_file);
virtual void RestoreSolverState(const SolverState& state) = 0;
SolverParameter param_;
int iter_;
Net<Dtype>* net_;
Net<Dtype>* test_net_;
DISABLE_COPY_AND_ASSIGN(Solver);
};
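SnapshotSolverState() and RestoreSolverState() are pure virtual, so every concrete solver must supply them. A hypothetical minimal subclass for illustration (the hunk elides part of the class, so a real subclass may need to implement further virtuals such as the update computation):

template <typename Dtype>
class NoHistorySolver : public Solver<Dtype> {  // hypothetical name
 public:
  explicit NoHistorySolver(const SolverParameter& param)
      : Solver<Dtype>(param) {}
 protected:
  // This solver keeps no state beyond iter_ and the net weights, so its
  // snapshot/restore hooks have nothing to persist.
  virtual void SnapshotSolverState(SolverState* state) {}
  virtual void RestoreSolverState(const SolverState& state) {}
};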

View file

@@ -74,6 +74,24 @@ void Caffe::set_random_seed(const unsigned int seed) {
VSL_CHECK(vslNewStream(&(Get().vsl_stream_), VSL_BRNG_MT19937, seed));
}
void Caffe::SetDevice(const int device_id) {
int current_device;
CUDA_CHECK(cudaGetDevice(&current_device));
if (current_device == device_id) {
return;
}
if (Get().cublas_handle_) CUBLAS_CHECK(cublasDestroy(Get().cublas_handle_));
if (Get().curand_generator_) {
CURAND_CHECK(curandDestroyGenerator(Get().curand_generator_));
}
CUDA_CHECK(cudaSetDevice(device_id));
CUBLAS_CHECK(cublasCreate(&Get().cublas_handle_));
CURAND_CHECK(curandCreateGenerator(&Get().curand_generator_,
CURAND_RNG_PSEUDO_DEFAULT));
CURAND_CHECK(curandSetPseudoRandomGeneratorSeed(Get().curand_generator_,
time(NULL)));
}
void Caffe::DeviceQuery() {
cudaDeviceProp prop;
int device;
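A hedged usage sketch of the new device API: because SetDevice() tears down and recreates the cuBLAS/cuRAND handles, it should be called before any work that uses them.

#include "caffe/common.hpp"
using namespace caffe;

int main() {
  // Select the GPU first; this is a no-op if it is already current,
  // otherwise the cublas/curand handles are destroyed and recreated.
  Caffe::SetDevice(0);
  Caffe::set_mode(Caffe::GPU);
  Caffe::DeviceQuery();  // print the status of the now-current device
  return 0;
}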

View file

@@ -101,7 +101,7 @@ void* DataLayerPrefetch(void* layer_pointer) {
layer->iter_->Next();
if (!layer->iter_->Valid()) {
// We have reached the end. Restart from the first.
LOG(INFO) << "Restarting data prefetching from start.";
DLOG(INFO) << "Restarting data prefetching from start.";
layer->iter_->SeekToFirst();
}
}
@@ -180,10 +180,10 @@ void DataLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
prefetch_data_->mutable_cpu_data();
prefetch_label_->mutable_cpu_data();
data_mean_.cpu_data();
// LOG(INFO) << "Initializing prefetch";
DLOG(INFO) << "Initializing prefetch";
CHECK(!pthread_create(&thread_, NULL, DataLayerPrefetch<Dtype>,
reinterpret_cast<void*>(this))) << "Pthread execution failed.";
// LOG(INFO) << "Prefetch initialized.";
DLOG(INFO) << "Prefetch initialized.";
}
template <typename Dtype>

View file

@@ -121,7 +121,7 @@ Net<Dtype>::Net(const NetParameter& param,
// In the end, all remaining blobs are considered output blobs.
for (set<string>::iterator it = available_blobs.begin();
it != available_blobs.end(); ++it) {
LOG(ERROR) << "This network produces output " << *it;
LOG(INFO) << "This network produces output " << *it;
net_output_blob_indices_.push_back(blob_name_to_idx[*it]);
net_output_blobs_.push_back(blobs_[blob_name_to_idx[*it]].get());
}
@@ -207,10 +207,10 @@ void Net<Dtype>::CopyTrainedLayersFrom(const NetParameter& param) {
++target_layer_id;
}
if (target_layer_id == layer_names_.size()) {
LOG(INFO) << "Ignoring source layer " << source_layer_name;
DLOG(INFO) << "Ignoring source layer " << source_layer_name;
continue;
}
LOG(INFO) << "Loading source layer " << source_layer_name;
DLOG(INFO) << "Loading source layer " << source_layer_name;
vector<shared_ptr<Blob<Dtype> > >& target_blobs =
layers_[target_layer_id]->blobs();
CHECK_EQ(target_blobs.size(), source_layer.blobs_size())
@@ -233,7 +233,7 @@ void Net<Dtype>::ToProto(NetParameter* param, bool write_diff) {
for (int i = 0; i < net_input_blob_indices_.size(); ++i) {
param->add_input(blob_names_[net_input_blob_indices_[i]]);
}
LOG(INFO) << "Serializing " << layers_.size() << " layers";
DLOG(INFO) << "Serializing " << layers_.size() << " layers";
for (int i = 0; i < layers_.size(); ++i) {
LayerConnection* layer_connection = param->add_layers();
for (int j = 0; j < bottom_id_vecs_[i].size(); ++j) {

View file

@@ -94,25 +94,28 @@ message NetParameter {
}
message SolverParameter {
optional float base_lr = 1; // The base learning rate
optional string train_net = 1; // The proto file for the training net.
optional string test_net = 2; // The proto file for the testing net.
// The number of iterations for each testing phase.
optional int32 test_iter = 3 [ default = 0 ];
// The number of iterations between two testing phases.
optional int32 test_interval = 4 [ default = 0 ];
optional float base_lr = 5; // The base learning rate
// the number of iterations between displaying info. If display = 0, no info
// will be displayed.
optional int32 display = 2;
optional int32 max_iter = 3; // the maximum number of iterations
optional int32 snapshot = 4 [default = 0]; // The snapshot interval
optional string lr_policy = 5; // The learning rate decay policy.
optional float min_lr = 6 [default = 0]; // The minimum learning rate
optional float max_lr = 7 [default = 1e10]; // The maximum learning rate
optional float gamma = 8; // The parameter to compute the learning rate.
optional float power = 9; // The parameter to compute the learning rate.
optional float momentum = 10; // The momentum value.
optional float weight_decay = 11; // The weight decay.
optional int32 stepsize = 12; // the stepsize for learning rate policy "step"
optional string snapshot_prefix = 13; // The prefix for the snapshot.
optional int32 display = 6;
optional int32 max_iter = 7; // the maximum number of iterations
optional string lr_policy = 8; // The learning rate decay policy.
optional float gamma = 9; // The parameter to compute the learning rate.
optional float power = 10; // The parameter to compute the learning rate.
optional float momentum = 11; // The momentum value.
optional float weight_decay = 12; // The weight decay.
optional int32 stepsize = 13; // the stepsize for learning rate policy "step"
optional int32 snapshot = 14 [default = 0]; // The snapshot interval
optional string snapshot_prefix = 15; // The prefix for the snapshot.
// whether to snapshot diff in the results or not. Snapshotting diff will help
// debugging but the final protocol buffer size will be much larger.
optional bool snapshot_diff = 14 [ default = false];
optional bool snapshot_diff = 16 [ default = false];
}
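// Note: the field numbers above are reassigned (base_lr moves from tag 1 to
// 5, display from 2 to 6, and so on), so SolverParameter messages serialized
// in binary form before this commit will no longer decode correctly;
// text-format prototxt files match by field name and are unaffected.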
// A message that stores the solver snapshots

View file

@@ -18,8 +18,31 @@ using std::min;
namespace caffe {
template <typename Dtype>
void Solver<Dtype>::Solve(Net<Dtype>* net, char* resume_file) {
net_ = net;
Solver<Dtype>::Solver(const SolverParameter& param)
: param_(param), net_(NULL), test_net_(NULL) {
// Scaffolding code
NetParameter train_net_param;
ReadProtoFromTextFile(param_.train_net(), &train_net_param);
// For the training network, there should be no input - so we simply create
// a dummy bottom_vec instance to initialize the networks.
vector<Blob<Dtype>*> bottom_vec;
LOG(INFO) << "Creating training net.";
net_ = new Net<Dtype>(train_net_param, bottom_vec);
if (param_.has_test_net()) {
LOG(INFO) << "Creating testing net.";
NetParameter test_net_param;
ReadProtoFromTextFile(param_.test_net(), &test_net_param);
test_net_ = new Net<Dtype>(test_net_param, bottom_vec);
CHECK_GT(param_.test_iter(), 0);
CHECK_GT(param_.test_interval(), 0);
}
LOG(INFO) << "Solver scaffolding done.";
}
template <typename Dtype>
void Solver<Dtype>::Solve(const char* resume_file) {
Caffe::set_phase(Caffe::TRAIN);
LOG(INFO) << "Solving " << net_->name();
PreSolve();
@@ -42,13 +65,54 @@ void Solver<Dtype>::Solve(Net<Dtype>* net, char* resume_file) {
Snapshot();
}
if (param_.display() && iter_ % param_.display() == 0) {
LOG(ERROR) << "Iteration " << iter_ << ", loss = " << loss;
LOG(INFO) << "Iteration " << iter_ << ", loss = " << loss;
}
if (param_.test_interval() && iter_ % param_.test_interval() == 0) {
// We need to set phase to test before running.
Caffe::set_phase(Caffe::TEST);
Test();
Caffe::set_phase(Caffe::TRAIN);
}
}
LOG(INFO) << "Optimization Done.";
}
template <typename Dtype>
void Solver<Dtype>::Test() {
LOG(INFO) << "Testing net";
NetParameter net_param;
net_->ToProto(&net_param);
CHECK_NOTNULL(test_net_)->CopyTrainedLayersFrom(net_param);
vector<Dtype> test_score;
vector<Blob<Dtype>*> bottom_vec;
for (int i = 0; i < param_.test_iter(); ++i) {
const vector<Blob<Dtype>*>& result =
test_net_->Forward(bottom_vec);
if (i == 0) {
for (int j = 0; j < result.size(); ++j) {
const Dtype* result_vec = result[j]->cpu_data();
for (int k = 0; k < result[j]->count(); ++k) {
test_score.push_back(result_vec[k]);
}
}
} else {
int idx = 0;
for (int j = 0; j < result.size(); ++j) {
const Dtype* result_vec = result[j]->cpu_data();
for (int k = 0; k < result[j]->count(); ++k) {
test_score[idx++] += result_vec[k];
}
}
}
}
for (int i = 0; i < test_score.size(); ++i) {
LOG(INFO) << "Test score #" << i << ": "
<< test_score[i] / param_.test_iter();
}
}
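// Note: with the lenet solver above (test_iter: 100), each forward pass of
// the test net yields its output blobs (here a single accuracy scalar);
// Test() accumulates every output element over test_iter passes and divides
// by test_iter, replacing the hand-written accuracy loop deleted from
// demo_mnist.cpp in this commit.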
template <typename Dtype>
void Solver<Dtype>::Snapshot() {
NetParameter net_param;
@@ -70,7 +134,7 @@ void Solver<Dtype>::Snapshot() {
}
template <typename Dtype>
void Solver<Dtype>::Restore(char* state_file) {
void Solver<Dtype>::Restore(const char* state_file) {
SolverState state;
NetParameter net_param;
ReadProtoFromBinaryFile(state_file, &state);
@@ -108,8 +172,6 @@ Dtype SGDSolver<Dtype>::GetLearningRate() {
} else {
LOG(FATAL) << "Unknown learning rate policy: " << lr_policy;
}
rate = min(max(rate, Dtype(this->param_.min_lr())),
Dtype(this->param_.max_lr()));
return rate;
}
@@ -136,7 +198,7 @@ void SGDSolver<Dtype>::ComputeUpdateValue() {
// get the learning rate
Dtype rate = GetLearningRate();
if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
LOG(ERROR) << "Iteration " << this->iter_ << ", lr = " << rate;
LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate;
}
Dtype momentum = this->param_.momentum();
Dtype weight_decay = this->param_.weight_decay();
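For reference, a small standalone sketch of the learning rate schedule that data/lenet_solver.prototxt requests. This assumes the usual Caffe "inv" policy formula, base_lr * (1 + gamma * iter)^(-power); the relevant branch of GetLearningRate() is elided from the hunk above, so treat the formula as an assumption:

#include <cmath>
#include <cstdio>

int main() {
  // Hyperparameters from data/lenet_solver.prototxt.
  const double base_lr = 0.01, gamma = 0.0001, power = 0.75;
  for (int iter = 0; iter <= 5000; iter += 1000) {
    // Assumed "inv" schedule: base_lr * (1 + gamma * iter)^(-power).
    double rate = base_lr * std::pow(1.0 + gamma * iter, -power);
    std::printf("iter %4d  lr %.6f\n", iter, rate);
  }
  return 0;
}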