зеркало из https://github.com/microsoft/caffe.git
solver restructuring: now all prototxt are specified in the solver protocol buffer
This commit is contained in:
Родитель
25a865cd8b
Коммит
82b912be84
2
Makefile
2
Makefile
|
@ -47,7 +47,7 @@ LIBRARIES := cuda cudart cublas curand protobuf opencv_core opencv_highgui \
|
|||
glog mkl_rt mkl_intel_thread leveldb snappy pthread
|
||||
WARNINGS := -Wall
|
||||
|
||||
COMMON_FLAGS := $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir))
|
||||
COMMON_FLAGS := -DNDEBUG $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir))
|
||||
CXXFLAGS += -pthread -fPIC -O2 $(COMMON_FLAGS)
|
||||
NVCCFLAGS := -Xcompiler -fPIC -O2 $(COMMON_FLAGS)
|
||||
LDFLAGS += $(foreach librarydir,$(LIBRARY_DIRS),-L$(librarydir)) \
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
train_net: "data/lenet.prototxt"
|
||||
test_net: "data/lenet_test.prototxt"
|
||||
base_lr: 0.01
|
||||
lr_policy: "inv"
|
||||
gamma: 0.0001
|
||||
power: 0.75
|
||||
display: 100
|
||||
max_iter: 5000
|
||||
momentum: 0.9
|
||||
weight_decay: 0.0005
|
||||
test_iter: 100
|
||||
test_interval: 500
|
|
@ -1,96 +0,0 @@
|
|||
// Copyright 2013 Yangqing Jia
|
||||
// This example shows how to run a modified version of LeNet using Caffe.
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <fcntl.h>
|
||||
#include <google/protobuf/text_format.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
|
||||
#include "caffe/blob.hpp"
|
||||
#include "caffe/common.hpp"
|
||||
#include "caffe/net.hpp"
|
||||
#include "caffe/filler.hpp"
|
||||
#include "caffe/proto/caffe.pb.h"
|
||||
#include "caffe/util/io.hpp"
|
||||
#include "caffe/solver.hpp"
|
||||
|
||||
using namespace caffe;
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
if (argc < 3) {
|
||||
std::cout << "Usage:" << std::endl;
|
||||
std::cout << "demo_mnist.bin train_file test_file [CPU/GPU]" << std::endl;
|
||||
return 0;
|
||||
}
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
Caffe::DeviceQuery();
|
||||
|
||||
if (argc == 4 && strcmp(argv[3], "GPU") == 0) {
|
||||
LOG(ERROR) << "Using GPU";
|
||||
Caffe::set_mode(Caffe::GPU);
|
||||
} else {
|
||||
LOG(ERROR) << "Using CPU";
|
||||
Caffe::set_mode(Caffe::CPU);
|
||||
}
|
||||
|
||||
// Start training
|
||||
Caffe::set_phase(Caffe::TRAIN);
|
||||
|
||||
NetParameter net_param;
|
||||
ReadProtoFromTextFile(argv[1],
|
||||
&net_param);
|
||||
vector<Blob<float>*> bottom_vec;
|
||||
Net<float> caffe_net(net_param, bottom_vec);
|
||||
|
||||
// Run the network without training.
|
||||
LOG(ERROR) << "Performing Forward";
|
||||
caffe_net.Forward(bottom_vec);
|
||||
LOG(ERROR) << "Performing Backward";
|
||||
LOG(ERROR) << "Initial loss: " << caffe_net.Backward();
|
||||
|
||||
SolverParameter solver_param;
|
||||
// Solver Parameters are hard-coded in this case, but you can write a
|
||||
// SolverParameter protocol buffer to specify all these values.
|
||||
solver_param.set_base_lr(0.01);
|
||||
solver_param.set_display(100);
|
||||
solver_param.set_max_iter(5000);
|
||||
solver_param.set_lr_policy("inv");
|
||||
solver_param.set_gamma(0.0001);
|
||||
solver_param.set_power(0.75);
|
||||
solver_param.set_momentum(0.9);
|
||||
solver_param.set_weight_decay(0.0005);
|
||||
|
||||
LOG(ERROR) << "Starting Optimization";
|
||||
SGDSolver<float> solver(solver_param);
|
||||
solver.Solve(&caffe_net);
|
||||
LOG(ERROR) << "Optimization Done.";
|
||||
|
||||
// Write the trained network to a NetParameter protobuf. If you are training
|
||||
// the model and saving it for later, this is what you want to serialize and
|
||||
// store.
|
||||
NetParameter trained_net_param;
|
||||
caffe_net.ToProto(&trained_net_param);
|
||||
|
||||
// Now, let's starting doing testing.
|
||||
Caffe::set_phase(Caffe::TEST);
|
||||
|
||||
// Using the testing data to test the accuracy.
|
||||
NetParameter test_net_param;
|
||||
ReadProtoFromTextFile(argv[2], &test_net_param);
|
||||
Net<float> caffe_test_net(test_net_param, bottom_vec);
|
||||
caffe_test_net.CopyTrainedLayersFrom(trained_net_param);
|
||||
|
||||
double test_accuracy = 0;
|
||||
int batch_size = test_net_param.layers(0).layer().batchsize();
|
||||
for (int i = 0; i < 10000 / batch_size; ++i) {
|
||||
const vector<Blob<float>*>& result =
|
||||
caffe_test_net.Forward(bottom_vec);
|
||||
test_accuracy += result[0]->cpu_data()[0];
|
||||
}
|
||||
test_accuracy /= 10000 / batch_size;
|
||||
LOG(ERROR) << "Test accuracy:" << test_accuracy;
|
||||
|
||||
return 0;
|
||||
}
|
|
@ -14,33 +14,27 @@
|
|||
using namespace caffe;
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
if (argc < 3) {
|
||||
LOG(ERROR) << "Usage: train_net net_proto_file solver_proto_file "
|
||||
<< "[resume_point_file]";
|
||||
::google::InitGoogleLogging(argv[0]);
|
||||
if (argc < 2) {
|
||||
LOG(ERROR) << "Usage: train_net solver_proto_file [resume_point_file]";
|
||||
return 0;
|
||||
}
|
||||
|
||||
cudaSetDevice(0);
|
||||
Caffe::SetDevice(0);
|
||||
Caffe::set_mode(Caffe::GPU);
|
||||
Caffe::set_phase(Caffe::TRAIN);
|
||||
|
||||
NetParameter net_param;
|
||||
ReadProtoFromTextFile(argv[1], &net_param);
|
||||
vector<Blob<float>*> bottom_vec;
|
||||
Net<float> caffe_net(net_param, bottom_vec);
|
||||
|
||||
SolverParameter solver_param;
|
||||
ReadProtoFromTextFile(argv[2], &solver_param);
|
||||
ReadProtoFromTextFile(argv[1], &solver_param);
|
||||
|
||||
LOG(ERROR) << "Starting Optimization";
|
||||
LOG(INFO) << "Starting Optimization";
|
||||
SGDSolver<float> solver(solver_param);
|
||||
if (argc == 4) {
|
||||
LOG(ERROR) << "Resuming from " << argv[3];
|
||||
solver.Solve(&caffe_net, argv[3]);
|
||||
if (argc == 3) {
|
||||
LOG(INFO) << "Resuming from " << argv[2];
|
||||
solver.Solve(argv[2]);
|
||||
} else {
|
||||
solver.Solve(&caffe_net);
|
||||
solver.Solve();
|
||||
}
|
||||
LOG(ERROR) << "Optimization Done.";
|
||||
LOG(INFO) << "Optimization Done.";
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
|
|
@ -94,6 +94,9 @@ class Caffe {
|
|||
inline static void set_phase(Phase phase) { Get().phase_ = phase; }
|
||||
// Sets the random seed of both MKL and curand
|
||||
static void set_random_seed(const unsigned int seed);
|
||||
// Sets the device. Since we have cublas and curand stuff, set device also
|
||||
// requires us to reset those values.
|
||||
static void SetDevice(const int device_id);
|
||||
// Prints the current GPU status.
|
||||
static void DeviceQuery();
|
||||
|
||||
|
|
|
@ -10,11 +10,11 @@ namespace caffe {
|
|||
template <typename Dtype>
|
||||
class Solver {
|
||||
public:
|
||||
explicit Solver(const SolverParameter& param)
|
||||
: param_(param) {}
|
||||
explicit Solver(const SolverParameter& param);
|
||||
// The main entry of the solver function. In default, iter will be zero. Pass
|
||||
// in a non-zero iter number to resume training for a pre-trained net.
|
||||
void Solve(Net<Dtype>* net, char* state_file = NULL);
|
||||
void Solve(const char* resume_file = NULL);
|
||||
inline void Solve(const string resume_file) { Solve(resume_file.c_str()); }
|
||||
virtual ~Solver() {}
|
||||
|
||||
protected:
|
||||
|
@ -28,15 +28,18 @@ class Solver {
|
|||
// function that produces a SolverState protocol buffer that needs to be
|
||||
// written to disk together with the learned net.
|
||||
void Snapshot();
|
||||
// The test routine
|
||||
void Test();
|
||||
virtual void SnapshotSolverState(SolverState* state) = 0;
|
||||
// The Restore function implements how one should restore the solver to a
|
||||
// previously snapshotted state. You should implement the RestoreSolverState()
|
||||
// function that restores the state from a SolverState protocol buffer.
|
||||
void Restore(char* state_file);
|
||||
void Restore(const char* resume_file);
|
||||
virtual void RestoreSolverState(const SolverState& state) = 0;
|
||||
SolverParameter param_;
|
||||
int iter_;
|
||||
Net<Dtype>* net_;
|
||||
Net<Dtype>* test_net_;
|
||||
|
||||
DISABLE_COPY_AND_ASSIGN(Solver);
|
||||
};
|
||||
|
|
|
@ -74,6 +74,24 @@ void Caffe::set_random_seed(const unsigned int seed) {
|
|||
VSL_CHECK(vslNewStream(&(Get().vsl_stream_), VSL_BRNG_MT19937, seed));
|
||||
}
|
||||
|
||||
void Caffe::SetDevice(const int device_id) {
|
||||
int current_device;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device));
|
||||
if (current_device == device_id) {
|
||||
return;
|
||||
}
|
||||
if (Get().cublas_handle_) CUBLAS_CHECK(cublasDestroy(Get().cublas_handle_));
|
||||
if (Get().curand_generator_) {
|
||||
CURAND_CHECK(curandDestroyGenerator(Get().curand_generator_));
|
||||
}
|
||||
CUDA_CHECK(cudaSetDevice(device_id));
|
||||
CUBLAS_CHECK(cublasCreate(&Get().cublas_handle_));
|
||||
CURAND_CHECK(curandCreateGenerator(&Get().curand_generator_,
|
||||
CURAND_RNG_PSEUDO_DEFAULT));
|
||||
CURAND_CHECK(curandSetPseudoRandomGeneratorSeed(Get().curand_generator_,
|
||||
time(NULL)));
|
||||
}
|
||||
|
||||
void Caffe::DeviceQuery() {
|
||||
cudaDeviceProp prop;
|
||||
int device;
|
||||
|
|
|
@ -101,7 +101,7 @@ void* DataLayerPrefetch(void* layer_pointer) {
|
|||
layer->iter_->Next();
|
||||
if (!layer->iter_->Valid()) {
|
||||
// We have reached the end. Restart from the first.
|
||||
LOG(INFO) << "Restarting data prefetching from start.";
|
||||
DLOG(INFO) << "Restarting data prefetching from start.";
|
||||
layer->iter_->SeekToFirst();
|
||||
}
|
||||
}
|
||||
|
@ -180,10 +180,10 @@ void DataLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
|
|||
prefetch_data_->mutable_cpu_data();
|
||||
prefetch_label_->mutable_cpu_data();
|
||||
data_mean_.cpu_data();
|
||||
// LOG(INFO) << "Initializing prefetch";
|
||||
DLOG(INFO) << "Initializing prefetch";
|
||||
CHECK(!pthread_create(&thread_, NULL, DataLayerPrefetch<Dtype>,
|
||||
reinterpret_cast<void*>(this))) << "Pthread execution failed.";
|
||||
// LOG(INFO) << "Prefetch initialized.";
|
||||
DLOG(INFO) << "Prefetch initialized.";
|
||||
}
|
||||
|
||||
template <typename Dtype>
|
||||
|
|
|
@ -121,7 +121,7 @@ Net<Dtype>::Net(const NetParameter& param,
|
|||
// In the end, all remaining blobs are considered output blobs.
|
||||
for (set<string>::iterator it = available_blobs.begin();
|
||||
it != available_blobs.end(); ++it) {
|
||||
LOG(ERROR) << "This network produces output " << *it;
|
||||
LOG(INFO) << "This network produces output " << *it;
|
||||
net_output_blob_indices_.push_back(blob_name_to_idx[*it]);
|
||||
net_output_blobs_.push_back(blobs_[blob_name_to_idx[*it]].get());
|
||||
}
|
||||
|
@ -207,10 +207,10 @@ void Net<Dtype>::CopyTrainedLayersFrom(const NetParameter& param) {
|
|||
++target_layer_id;
|
||||
}
|
||||
if (target_layer_id == layer_names_.size()) {
|
||||
LOG(INFO) << "Ignoring source layer " << source_layer_name;
|
||||
DLOG(INFO) << "Ignoring source layer " << source_layer_name;
|
||||
continue;
|
||||
}
|
||||
LOG(INFO) << "Loading source layer " << source_layer_name;
|
||||
DLOG(INFO) << "Loading source layer " << source_layer_name;
|
||||
vector<shared_ptr<Blob<Dtype> > >& target_blobs =
|
||||
layers_[target_layer_id]->blobs();
|
||||
CHECK_EQ(target_blobs.size(), source_layer.blobs_size())
|
||||
|
@ -233,7 +233,7 @@ void Net<Dtype>::ToProto(NetParameter* param, bool write_diff) {
|
|||
for (int i = 0; i < net_input_blob_indices_.size(); ++i) {
|
||||
param->add_input(blob_names_[net_input_blob_indices_[i]]);
|
||||
}
|
||||
LOG(INFO) << "Serializing " << layers_.size() << " layers";
|
||||
DLOG(INFO) << "Serializing " << layers_.size() << " layers";
|
||||
for (int i = 0; i < layers_.size(); ++i) {
|
||||
LayerConnection* layer_connection = param->add_layers();
|
||||
for (int j = 0; j < bottom_id_vecs_[i].size(); ++j) {
|
||||
|
|
|
@ -94,25 +94,28 @@ message NetParameter {
|
|||
}
|
||||
|
||||
message SolverParameter {
|
||||
optional float base_lr = 1; // The base learning rate
|
||||
optional string train_net = 1; // The proto file for the training net.
|
||||
optional string test_net = 2; // The proto file for the testing net.
|
||||
// The number of iterations for each testing phase.
|
||||
optional int32 test_iter = 3 [ default = 0 ];
|
||||
// The number of iterations between two testing phases.
|
||||
optional int32 test_interval = 4 [ default = 0 ];
|
||||
optional float base_lr = 5; // The base learning rate
|
||||
// the number of iterations between displaying info. If display = 0, no info
|
||||
// will be displayed.
|
||||
optional int32 display = 2;
|
||||
optional int32 max_iter = 3; // the maximum number of iterations
|
||||
optional int32 snapshot = 4 [default = 0]; // The snapshot interval
|
||||
optional string lr_policy = 5; // The learning rate decay policy.
|
||||
optional float min_lr = 6 [default = 0]; // The mininum learning rate
|
||||
optional float max_lr = 7 [default = 1e10]; // The maximum learning rate
|
||||
optional float gamma = 8; // The parameter to compute the learning rate.
|
||||
optional float power = 9; // The parameter to compute the learning rate.
|
||||
optional float momentum = 10; // The momentum value.
|
||||
optional float weight_decay = 11; // The weight decay.
|
||||
optional int32 stepsize = 12; // the stepsize for learning rate policy "step"
|
||||
|
||||
optional string snapshot_prefix = 13; // The prefix for the snapshot.
|
||||
optional int32 display = 6;
|
||||
optional int32 max_iter = 7; // the maximum number of iterations
|
||||
optional string lr_policy = 8; // The learning rate decay policy.
|
||||
optional float gamma = 9; // The parameter to compute the learning rate.
|
||||
optional float power = 10; // The parameter to compute the learning rate.
|
||||
optional float momentum = 11; // The momentum value.
|
||||
optional float weight_decay = 12; // The weight decay.
|
||||
optional int32 stepsize = 13; // the stepsize for learning rate policy "step"
|
||||
optional int32 snapshot = 14 [default = 0]; // The snapshot interval
|
||||
optional string snapshot_prefix = 15; // The prefix for the snapshot.
|
||||
// whether to snapshot diff in the results or not. Snapshotting diff will help
|
||||
// debugging but the final protocol buffer size will be much larger.
|
||||
optional bool snapshot_diff = 14 [ default = false];
|
||||
optional bool snapshot_diff = 16 [ default = false];
|
||||
}
|
||||
|
||||
// A message that stores the solver snapshots
|
||||
|
|
|
@ -18,8 +18,31 @@ using std::min;
|
|||
namespace caffe {
|
||||
|
||||
template <typename Dtype>
|
||||
void Solver<Dtype>::Solve(Net<Dtype>* net, char* resume_file) {
|
||||
net_ = net;
|
||||
Solver<Dtype>::Solver(const SolverParameter& param)
|
||||
: param_(param), net_(NULL), test_net_(NULL) {
|
||||
// Scaffolding code
|
||||
NetParameter train_net_param;
|
||||
ReadProtoFromTextFile(param_.train_net(), &train_net_param);
|
||||
// For the training network, there should be no input - so we simply create
|
||||
// a dummy bottom_vec instance to initialize the networks.
|
||||
vector<Blob<Dtype>*> bottom_vec;
|
||||
LOG(INFO) << "Creating training net.";
|
||||
net_ = new Net<Dtype>(train_net_param, bottom_vec);
|
||||
if (param_.has_test_net()) {
|
||||
LOG(INFO) << "Creating testing net.";
|
||||
NetParameter test_net_param;
|
||||
ReadProtoFromTextFile(param_.test_net(), &test_net_param);
|
||||
test_net_ = new Net<Dtype>(test_net_param, bottom_vec);
|
||||
CHECK_GT(param_.test_iter(), 0);
|
||||
CHECK_GT(param_.test_interval(), 0);
|
||||
}
|
||||
LOG(INFO) << "Solver scaffolding done.";
|
||||
}
|
||||
|
||||
|
||||
template <typename Dtype>
|
||||
void Solver<Dtype>::Solve(const char* resume_file) {
|
||||
Caffe::set_phase(Caffe::TRAIN);
|
||||
LOG(INFO) << "Solving " << net_->name();
|
||||
PreSolve();
|
||||
|
||||
|
@ -42,13 +65,54 @@ void Solver<Dtype>::Solve(Net<Dtype>* net, char* resume_file) {
|
|||
Snapshot();
|
||||
}
|
||||
if (param_.display() && iter_ % param_.display() == 0) {
|
||||
LOG(ERROR) << "Iteration " << iter_ << ", loss = " << loss;
|
||||
LOG(INFO) << "Iteration " << iter_ << ", loss = " << loss;
|
||||
}
|
||||
if (param_.test_interval() && iter_ % param_.test_interval() == 0) {
|
||||
// We need to set phase to test before running.
|
||||
Caffe::set_phase(Caffe::TEST);
|
||||
Test();
|
||||
Caffe::set_phase(Caffe::TRAIN);
|
||||
}
|
||||
}
|
||||
LOG(INFO) << "Optimization Done.";
|
||||
}
|
||||
|
||||
|
||||
template <typename Dtype>
|
||||
void Solver<Dtype>::Test() {
|
||||
LOG(INFO) << "Testing net";
|
||||
NetParameter net_param;
|
||||
net_->ToProto(&net_param);
|
||||
CHECK_NOTNULL(test_net_)->CopyTrainedLayersFrom(net_param);
|
||||
vector<Dtype> test_score;
|
||||
vector<Blob<Dtype>*> bottom_vec;
|
||||
for (int i = 0; i < param_.test_iter(); ++i) {
|
||||
const vector<Blob<Dtype>*>& result =
|
||||
test_net_->Forward(bottom_vec);
|
||||
if (i == 0) {
|
||||
for (int j = 0; j < result.size(); ++j) {
|
||||
const Dtype* result_vec = result[j]->cpu_data();
|
||||
for (int k = 0; k < result[j]->count(); ++k) {
|
||||
test_score.push_back(result_vec[k]);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
int idx = 0;
|
||||
for (int j = 0; j < result.size(); ++j) {
|
||||
const Dtype* result_vec = result[j]->cpu_data();
|
||||
for (int k = 0; k < result[j]->count(); ++k) {
|
||||
test_score[idx++] += result_vec[k];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < test_score.size(); ++i) {
|
||||
LOG(INFO) << "Test score #" << i << ": "
|
||||
<< test_score[i] / param_.test_iter();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename Dtype>
|
||||
void Solver<Dtype>::Snapshot() {
|
||||
NetParameter net_param;
|
||||
|
@ -70,7 +134,7 @@ void Solver<Dtype>::Snapshot() {
|
|||
}
|
||||
|
||||
template <typename Dtype>
|
||||
void Solver<Dtype>::Restore(char* state_file) {
|
||||
void Solver<Dtype>::Restore(const char* state_file) {
|
||||
SolverState state;
|
||||
NetParameter net_param;
|
||||
ReadProtoFromBinaryFile(state_file, &state);
|
||||
|
@ -108,8 +172,6 @@ Dtype SGDSolver<Dtype>::GetLearningRate() {
|
|||
} else {
|
||||
LOG(FATAL) << "Unknown learning rate policy: " << lr_policy;
|
||||
}
|
||||
rate = min(max(rate, Dtype(this->param_.min_lr())),
|
||||
Dtype(this->param_.max_lr()));
|
||||
return rate;
|
||||
}
|
||||
|
||||
|
@ -136,7 +198,7 @@ void SGDSolver<Dtype>::ComputeUpdateValue() {
|
|||
// get the learning rate
|
||||
Dtype rate = GetLearningRate();
|
||||
if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
|
||||
LOG(ERROR) << "Iteration " << this->iter_ << ", lr = " << rate;
|
||||
LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate;
|
||||
}
|
||||
Dtype momentum = this->param_.momentum();
|
||||
Dtype weight_decay = this->param_.weight_decay();
|
||||
|
|
Загрузка…
Ссылка в новой задаче