зеркало из https://github.com/microsoft/caffe.git
Change solver type to string and provide solver registry
This commit is contained in:
Родитель
b822a702d1
Коммит
0eea815ad6
|
@ -13,6 +13,7 @@
|
|||
#include "caffe/parallel.hpp"
|
||||
#include "caffe/proto/caffe.pb.h"
|
||||
#include "caffe/solver.hpp"
|
||||
#include "caffe/solver_factory.hpp"
|
||||
#include "caffe/util/benchmark.hpp"
|
||||
#include "caffe/util/io.hpp"
|
||||
#include "caffe/vision_layers.hpp"
|
||||
|
|
|
@ -19,6 +19,7 @@ class SGDSolver : public Solver<Dtype> {
|
|||
: Solver<Dtype>(param) { PreSolve(); }
|
||||
explicit SGDSolver(const string& param_file)
|
||||
: Solver<Dtype>(param_file) { PreSolve(); }
|
||||
virtual inline const char* type() const { return "SGD"; }
|
||||
|
||||
const vector<shared_ptr<Blob<Dtype> > >& history() { return history_; }
|
||||
|
||||
|
@ -51,6 +52,7 @@ class NesterovSolver : public SGDSolver<Dtype> {
|
|||
: SGDSolver<Dtype>(param) {}
|
||||
explicit NesterovSolver(const string& param_file)
|
||||
: SGDSolver<Dtype>(param_file) {}
|
||||
virtual inline const char* type() const { return "Nesterov"; }
|
||||
|
||||
protected:
|
||||
virtual void ComputeUpdateValue(int param_id, Dtype rate);
|
||||
|
@ -65,6 +67,7 @@ class AdaGradSolver : public SGDSolver<Dtype> {
|
|||
: SGDSolver<Dtype>(param) { constructor_sanity_check(); }
|
||||
explicit AdaGradSolver(const string& param_file)
|
||||
: SGDSolver<Dtype>(param_file) { constructor_sanity_check(); }
|
||||
virtual inline const char* type() const { return "AdaGrad"; }
|
||||
|
||||
protected:
|
||||
virtual void ComputeUpdateValue(int param_id, Dtype rate);
|
||||
|
@ -84,6 +87,7 @@ class RMSPropSolver : public SGDSolver<Dtype> {
|
|||
: SGDSolver<Dtype>(param) { constructor_sanity_check(); }
|
||||
explicit RMSPropSolver(const string& param_file)
|
||||
: SGDSolver<Dtype>(param_file) { constructor_sanity_check(); }
|
||||
virtual inline const char* type() const { return "RMSProp"; }
|
||||
|
||||
protected:
|
||||
virtual void ComputeUpdateValue(int param_id, Dtype rate);
|
||||
|
@ -106,6 +110,7 @@ class AdaDeltaSolver : public SGDSolver<Dtype> {
|
|||
: SGDSolver<Dtype>(param) { AdaDeltaPreSolve(); }
|
||||
explicit AdaDeltaSolver(const string& param_file)
|
||||
: SGDSolver<Dtype>(param_file) { AdaDeltaPreSolve(); }
|
||||
virtual inline const char* type() const { return "AdaDelta"; }
|
||||
|
||||
protected:
|
||||
void AdaDeltaPreSolve();
|
||||
|
@ -129,6 +134,7 @@ class AdamSolver : public SGDSolver<Dtype> {
|
|||
: SGDSolver<Dtype>(param) { AdamPreSolve();}
|
||||
explicit AdamSolver(const string& param_file)
|
||||
: SGDSolver<Dtype>(param_file) { AdamPreSolve(); }
|
||||
virtual inline const char* type() const { return "Adam"; }
|
||||
|
||||
protected:
|
||||
void AdamPreSolve();
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
#include <vector>
|
||||
|
||||
#include "caffe/net.hpp"
|
||||
#include "caffe/solver_factory.hpp"
|
||||
|
||||
namespace caffe {
|
||||
|
||||
|
@ -83,6 +84,10 @@ class Solver {
|
|||
}
|
||||
|
||||
void CheckSnapshotWritePermissions();
|
||||
/**
|
||||
* @brief Returns the solver type.
|
||||
*/
|
||||
virtual inline const char* type() const { return ""; }
|
||||
|
||||
protected:
|
||||
// Make and apply the update value for the current iteration.
|
||||
|
@ -148,10 +153,6 @@ class WorkerSolver : public Solver<Dtype> {
|
|||
}
|
||||
};
|
||||
|
||||
// The solver factory function
|
||||
template <typename Dtype>
|
||||
Solver<Dtype>* GetSolver(const SolverParameter& param);
|
||||
|
||||
} // namespace caffe
|
||||
|
||||
#endif // CAFFE_SOLVER_HPP_
|
||||
|
|
|
@ -0,0 +1,137 @@
|
|||
/**
|
||||
* @brief A solver factory that allows one to register solvers, similar to
|
||||
* layer factory. During runtime, registered solvers could be called by passing
|
||||
* a SolverParameter protobuffer to the CreateSolver function:
|
||||
*
|
||||
* SolverRegistry<Dtype>::CreateSolver(param);
|
||||
*
|
||||
* There are two ways to register a solver. Assuming that we have a solver like:
|
||||
*
|
||||
* template <typename Dtype>
|
||||
* class MyAwesomeSolver : public Solver<Dtype> {
|
||||
* // your implementations
|
||||
* };
|
||||
*
|
||||
* and its type is its C++ class name, but without the "Solver" at the end
|
||||
* ("MyAwesomeSolver" -> "MyAwesome").
|
||||
*
|
||||
* If the solver is going to be created simply by its constructor, in your c++
|
||||
* file, add the following line:
|
||||
*
|
||||
* REGISTER_SOLVER_CLASS(MyAwesome);
|
||||
*
|
||||
* Or, if the solver is going to be created by another creator function, in the
|
||||
* format of:
|
||||
*
|
||||
* template <typename Dtype>
|
||||
* Solver<Dtype*> GetMyAwesomeSolver(const SolverParameter& param) {
|
||||
* // your implementation
|
||||
* }
|
||||
*
|
||||
* then you can register the creator function instead, like
|
||||
*
|
||||
* REGISTER_SOLVER_CREATOR(MyAwesome, GetMyAwesomeSolver)
|
||||
*
|
||||
* Note that each solver type should only be registered once.
|
||||
*/
|
||||
|
||||
#ifndef CAFFE_SOLVER_FACTORY_H_
|
||||
#define CAFFE_SOLVER_FACTORY_H_
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "caffe/common.hpp"
|
||||
#include "caffe/proto/caffe.pb.h"
|
||||
|
||||
namespace caffe {
|
||||
|
||||
template <typename Dtype>
|
||||
class Solver;
|
||||
|
||||
template <typename Dtype>
|
||||
class SolverRegistry {
|
||||
public:
|
||||
typedef Solver<Dtype>* (*Creator)(const SolverParameter&);
|
||||
typedef std::map<string, Creator> CreatorRegistry;
|
||||
|
||||
static CreatorRegistry& Registry() {
|
||||
static CreatorRegistry* g_registry_ = new CreatorRegistry();
|
||||
return *g_registry_;
|
||||
}
|
||||
|
||||
// Adds a creator.
|
||||
static void AddCreator(const string& type, Creator creator) {
|
||||
CreatorRegistry& registry = Registry();
|
||||
CHECK_EQ(registry.count(type), 0)
|
||||
<< "Solver type " << type << " already registered.";
|
||||
registry[type] = creator;
|
||||
}
|
||||
|
||||
// Get a solver using a SolverParameter.
|
||||
static Solver<Dtype>* CreateSolver(const SolverParameter& param) {
|
||||
const string& type = param.type();
|
||||
CreatorRegistry& registry = Registry();
|
||||
CHECK_EQ(registry.count(type), 1) << "Unknown solver type: " << type
|
||||
<< " (known types: " << SolverTypeListString() << ")";
|
||||
return registry[type](param);
|
||||
}
|
||||
|
||||
static vector<string> SolverTypeList() {
|
||||
CreatorRegistry& registry = Registry();
|
||||
vector<string> solver_types;
|
||||
for (typename CreatorRegistry::iterator iter = registry.begin();
|
||||
iter != registry.end(); ++iter) {
|
||||
solver_types.push_back(iter->first);
|
||||
}
|
||||
return solver_types;
|
||||
}
|
||||
|
||||
private:
|
||||
// Solver registry should never be instantiated - everything is done with its
|
||||
// static variables.
|
||||
SolverRegistry() {}
|
||||
|
||||
static string SolverTypeListString() {
|
||||
vector<string> solver_types = SolverTypeList();
|
||||
string solver_types_str;
|
||||
for (vector<string>::iterator iter = solver_types.begin();
|
||||
iter != solver_types.end(); ++iter) {
|
||||
if (iter != solver_types.begin()) {
|
||||
solver_types_str += ", ";
|
||||
}
|
||||
solver_types_str += *iter;
|
||||
}
|
||||
return solver_types_str;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template <typename Dtype>
|
||||
class SolverRegisterer {
|
||||
public:
|
||||
SolverRegisterer(const string& type,
|
||||
Solver<Dtype>* (*creator)(const SolverParameter&)) {
|
||||
// LOG(INFO) << "Registering solver type: " << type;
|
||||
SolverRegistry<Dtype>::AddCreator(type, creator);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
#define REGISTER_SOLVER_CREATOR(type, creator) \
|
||||
static SolverRegisterer<float> g_creator_f_##type(#type, creator<float>); \
|
||||
static SolverRegisterer<double> g_creator_d_##type(#type, creator<double>) \
|
||||
|
||||
#define REGISTER_SOLVER_CLASS(type) \
|
||||
template <typename Dtype> \
|
||||
Solver<Dtype>* Creator_##type##Solver( \
|
||||
const SolverParameter& param) \
|
||||
{ \
|
||||
return new type##Solver<Dtype>(param); \
|
||||
} \
|
||||
REGISTER_SOLVER_CREATOR(type, Creator_##type##Solver)
|
||||
|
||||
} // namespace caffe
|
||||
|
||||
#endif // CAFFE_SOLVER_FACTORY_H_
|
|
@ -98,7 +98,7 @@ message NetParameter {
|
|||
// NOTE
|
||||
// Update the next available ID when you add a new SolverParameter field.
|
||||
//
|
||||
// SolverParameter next available ID: 40 (last added: momentum2)
|
||||
// SolverParameter next available ID: 41 (last added: type)
|
||||
message SolverParameter {
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// Specifying the train and test networks
|
||||
|
@ -209,16 +209,9 @@ message SolverParameter {
|
|||
// (and by default) initialize using a seed derived from the system clock.
|
||||
optional int64 random_seed = 20 [default = -1];
|
||||
|
||||
// Solver type
|
||||
enum SolverType {
|
||||
SGD = 0;
|
||||
NESTEROV = 1;
|
||||
ADAGRAD = 2;
|
||||
RMSPROP = 3;
|
||||
ADADELTA = 4;
|
||||
ADAM = 5;
|
||||
}
|
||||
optional SolverType solver_type = 30 [default = SGD];
|
||||
// type of the solver
|
||||
optional string type = 40 [default = "SGD"];
|
||||
|
||||
// numerical stability for RMSProp, AdaGrad and AdaDelta and Adam
|
||||
optional float delta = 31 [default = 1e-8];
|
||||
// parameters for the Adam solver
|
||||
|
@ -234,6 +227,18 @@ message SolverParameter {
|
|||
|
||||
// If false, don't save a snapshot after training finishes.
|
||||
optional bool snapshot_after_train = 28 [default = true];
|
||||
|
||||
// DEPRECATED: old solver enum types, use string instead
|
||||
enum SolverType {
|
||||
SGD = 0;
|
||||
NESTEROV = 1;
|
||||
ADAGRAD = 2;
|
||||
RMSPROP = 3;
|
||||
ADADELTA = 4;
|
||||
ADAM = 5;
|
||||
}
|
||||
// DEPRECATED: use type instead of solver_type
|
||||
optional SolverType solver_type = 30 [default = SGD];
|
||||
}
|
||||
|
||||
// A message that stores the solver snapshots
|
||||
|
|
|
@ -1,32 +0,0 @@
|
|||
#include "caffe/solver.hpp"
|
||||
#include "caffe/sgd_solvers.hpp"
|
||||
|
||||
namespace caffe {
|
||||
|
||||
template <typename Dtype>
|
||||
Solver<Dtype>* GetSolver(const SolverParameter& param) {
|
||||
SolverParameter_SolverType type = param.solver_type();
|
||||
|
||||
switch (type) {
|
||||
case SolverParameter_SolverType_SGD:
|
||||
return new SGDSolver<Dtype>(param);
|
||||
case SolverParameter_SolverType_NESTEROV:
|
||||
return new NesterovSolver<Dtype>(param);
|
||||
case SolverParameter_SolverType_ADAGRAD:
|
||||
return new AdaGradSolver<Dtype>(param);
|
||||
case SolverParameter_SolverType_RMSPROP:
|
||||
return new RMSPropSolver<Dtype>(param);
|
||||
case SolverParameter_SolverType_ADADELTA:
|
||||
return new AdaDeltaSolver<Dtype>(param);
|
||||
case SolverParameter_SolverType_ADAM:
|
||||
return new AdamSolver<Dtype>(param);
|
||||
default:
|
||||
LOG(FATAL) << "Unknown SolverType: " << type;
|
||||
}
|
||||
return (Solver<Dtype>*) NULL;
|
||||
}
|
||||
|
||||
template Solver<float>* GetSolver(const SolverParameter& param);
|
||||
template Solver<double>* GetSolver(const SolverParameter& param);
|
||||
|
||||
} // namespace caffe
|
|
@ -151,5 +151,6 @@ void AdaDeltaSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
|
|||
}
|
||||
|
||||
INSTANTIATE_CLASS(AdaDeltaSolver);
|
||||
REGISTER_SOLVER_CLASS(AdaDelta);
|
||||
|
||||
} // namespace caffe
|
||||
|
|
|
@ -84,5 +84,6 @@ void AdaGradSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
|
|||
}
|
||||
|
||||
INSTANTIATE_CLASS(AdaGradSolver);
|
||||
REGISTER_SOLVER_CLASS(AdaGrad);
|
||||
|
||||
} // namespace caffe
|
||||
|
|
|
@ -108,5 +108,6 @@ void AdamSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
|
|||
}
|
||||
|
||||
INSTANTIATE_CLASS(AdamSolver);
|
||||
REGISTER_SOLVER_CLASS(Adam);
|
||||
|
||||
} // namespace caffe
|
||||
|
|
|
@ -66,5 +66,6 @@ void NesterovSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
|
|||
}
|
||||
|
||||
INSTANTIATE_CLASS(NesterovSolver);
|
||||
REGISTER_SOLVER_CLASS(Nesterov);
|
||||
|
||||
} // namespace caffe
|
||||
|
|
|
@ -80,5 +80,6 @@ void RMSPropSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
|
|||
}
|
||||
|
||||
INSTANTIATE_CLASS(RMSPropSolver);
|
||||
REGISTER_SOLVER_CLASS(RMSProp);
|
||||
|
||||
} // namespace caffe
|
||||
|
|
|
@ -343,5 +343,6 @@ void SGDSolver<Dtype>::RestoreSolverStateFromHDF5(const string& state_file) {
|
|||
}
|
||||
|
||||
INSTANTIATE_CLASS(SGDSolver);
|
||||
REGISTER_SOLVER_CLASS(SGD);
|
||||
|
||||
} // namespace caffe
|
||||
|
|
|
@ -47,7 +47,6 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
|
|||
// Test data: check out generate_sample_data.py in the same directory.
|
||||
string* input_file_;
|
||||
|
||||
virtual SolverParameter_SolverType solver_type() = 0;
|
||||
virtual void InitSolver(const SolverParameter& param) = 0;
|
||||
|
||||
virtual void InitSolverFromProtoString(const string& proto) {
|
||||
|
@ -290,8 +289,8 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
|
|||
((i == D) ? bias.cpu_data()[0] : weights.cpu_data()[i]);
|
||||
// Finally, compute update.
|
||||
const vector<shared_ptr<Blob<Dtype> > >& history = solver_->history();
|
||||
if (solver_type() != SolverParameter_SolverType_ADADELTA
|
||||
&& solver_type() != SolverParameter_SolverType_ADAM) {
|
||||
if (solver_->type() != string("AdaDelta")
|
||||
&& solver_->type() != string("Adam")) {
|
||||
ASSERT_EQ(2, history.size()); // 1 blob for weights, 1 for bias
|
||||
} else {
|
||||
ASSERT_EQ(4, history.size()); // additional blobs for update history
|
||||
|
@ -300,26 +299,19 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
|
|||
const Dtype history_value = (i == D) ?
|
||||
history[1]->cpu_data()[0] : history[0]->cpu_data()[i];
|
||||
const Dtype temp = momentum * history_value;
|
||||
switch (solver_type()) {
|
||||
case SolverParameter_SolverType_SGD:
|
||||
if (solver_->type() == string("SGD")) {
|
||||
update_value += temp;
|
||||
break;
|
||||
case SolverParameter_SolverType_NESTEROV:
|
||||
} else if (solver_->type() == string("Nesterov")) {
|
||||
update_value += temp;
|
||||
// step back then over-step
|
||||
update_value = (1 + momentum) * update_value - temp;
|
||||
break;
|
||||
case SolverParameter_SolverType_ADAGRAD:
|
||||
} else if (solver_->type() == string("AdaGrad")) {
|
||||
update_value /= std::sqrt(history_value + grad * grad) + delta_;
|
||||
break;
|
||||
case SolverParameter_SolverType_RMSPROP: {
|
||||
} else if (solver_->type() == string("RMSProp")) {
|
||||
const Dtype rms_decay = 0.95;
|
||||
update_value /= std::sqrt(rms_decay*history_value
|
||||
+ grad * grad * (1 - rms_decay)) + delta_;
|
||||
}
|
||||
break;
|
||||
case SolverParameter_SolverType_ADADELTA:
|
||||
{
|
||||
} else if (solver_->type() == string("AdaDelta")) {
|
||||
const Dtype update_history_value = (i == D) ?
|
||||
history[1 + num_param_blobs]->cpu_data()[0] :
|
||||
history[0 + num_param_blobs]->cpu_data()[i];
|
||||
|
@ -330,9 +322,7 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
|
|||
// not actually needed, just here for illustrative purposes
|
||||
// const Dtype weighted_update_average =
|
||||
// momentum * update_history_value + (1 - momentum) * (update_value);
|
||||
break;
|
||||
}
|
||||
case SolverParameter_SolverType_ADAM: {
|
||||
} else if (solver_->type() == string("Adam")) {
|
||||
const Dtype momentum2 = 0.999;
|
||||
const Dtype m = history_value;
|
||||
const Dtype v = (i == D) ?
|
||||
|
@ -344,10 +334,8 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
|
|||
std::sqrt(Dtype(1) - pow(momentum2, num_iters)) /
|
||||
(Dtype(1.) - pow(momentum, num_iters));
|
||||
update_value = alpha_t * val_m / (std::sqrt(val_v) + delta_);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
LOG(FATAL) << "Unknown solver type: " << solver_type();
|
||||
} else {
|
||||
LOG(FATAL) << "Unknown solver type: " << solver_->type();
|
||||
}
|
||||
if (i == D) {
|
||||
updated_bias.mutable_cpu_diff()[0] = update_value;
|
||||
|
@ -392,7 +380,7 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
|
|||
EXPECT_NEAR(expected_updated_bias, solver_updated_bias, error_margin);
|
||||
|
||||
// Check the solver's history -- should contain the previous update value.
|
||||
if (solver_type() == SolverParameter_SolverType_SGD) {
|
||||
if (solver_->type() == string("SGD")) {
|
||||
const vector<shared_ptr<Blob<Dtype> > >& history = solver_->history();
|
||||
ASSERT_EQ(2, history.size());
|
||||
for (int i = 0; i < D; ++i) {
|
||||
|
@ -581,10 +569,6 @@ class SGDSolverTest : public GradientBasedSolverTest<TypeParam> {
|
|||
virtual void InitSolver(const SolverParameter& param) {
|
||||
this->solver_.reset(new SGDSolver<Dtype>(param));
|
||||
}
|
||||
|
||||
virtual SolverParameter_SolverType solver_type() {
|
||||
return SolverParameter_SolverType_SGD;
|
||||
}
|
||||
};
|
||||
|
||||
TYPED_TEST_CASE(SGDSolverTest, TestDtypesAndDevices);
|
||||
|
@ -721,9 +705,6 @@ class AdaGradSolverTest : public GradientBasedSolverTest<TypeParam> {
|
|||
virtual void InitSolver(const SolverParameter& param) {
|
||||
this->solver_.reset(new AdaGradSolver<Dtype>(param));
|
||||
}
|
||||
virtual SolverParameter_SolverType solver_type() {
|
||||
return SolverParameter_SolverType_ADAGRAD;
|
||||
}
|
||||
};
|
||||
|
||||
TYPED_TEST_CASE(AdaGradSolverTest, TestDtypesAndDevices);
|
||||
|
@ -824,9 +805,6 @@ class NesterovSolverTest : public GradientBasedSolverTest<TypeParam> {
|
|||
virtual void InitSolver(const SolverParameter& param) {
|
||||
this->solver_.reset(new NesterovSolver<Dtype>(param));
|
||||
}
|
||||
virtual SolverParameter_SolverType solver_type() {
|
||||
return SolverParameter_SolverType_NESTEROV;
|
||||
}
|
||||
};
|
||||
|
||||
TYPED_TEST_CASE(NesterovSolverTest, TestDtypesAndDevices);
|
||||
|
@ -960,10 +938,6 @@ class AdaDeltaSolverTest : public GradientBasedSolverTest<TypeParam> {
|
|||
virtual void InitSolver(const SolverParameter& param) {
|
||||
this->solver_.reset(new AdaDeltaSolver<Dtype>(param));
|
||||
}
|
||||
|
||||
virtual SolverParameter_SolverType solver_type() {
|
||||
return SolverParameter_SolverType_ADADELTA;
|
||||
}
|
||||
};
|
||||
|
||||
TYPED_TEST_CASE(AdaDeltaSolverTest, TestDtypesAndDevices);
|
||||
|
@ -1098,9 +1072,6 @@ class AdamSolverTest : public GradientBasedSolverTest<TypeParam> {
|
|||
new_param.set_momentum2(momentum2);
|
||||
this->solver_.reset(new AdamSolver<Dtype>(new_param));
|
||||
}
|
||||
virtual SolverParameter_SolverType solver_type() {
|
||||
return SolverParameter_SolverType_ADAM;
|
||||
}
|
||||
};
|
||||
|
||||
TYPED_TEST_CASE(AdamSolverTest, TestDtypesAndDevices);
|
||||
|
@ -1201,9 +1172,6 @@ class RMSPropSolverTest : public GradientBasedSolverTest<TypeParam> {
|
|||
new_param.set_rms_decay(rms_decay);
|
||||
this->solver_.reset(new RMSPropSolver<Dtype>(new_param));
|
||||
}
|
||||
virtual SolverParameter_SolverType solver_type() {
|
||||
return SolverParameter_SolverType_RMSPROP;
|
||||
}
|
||||
};
|
||||
|
||||
TYPED_TEST_CASE(RMSPropSolverTest, TestDtypesAndDevices);
|
||||
|
|
|
@ -0,0 +1,50 @@
|
|||
#include <map>
|
||||
#include <string>
|
||||
|
||||
#include "boost/scoped_ptr.hpp"
|
||||
#include "google/protobuf/text_format.h"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "caffe/common.hpp"
|
||||
#include "caffe/solver.hpp"
|
||||
#include "caffe/solver_factory.hpp"
|
||||
|
||||
#include "caffe/test/test_caffe_main.hpp"
|
||||
|
||||
namespace caffe {
|
||||
|
||||
template <typename TypeParam>
|
||||
class SolverFactoryTest : public MultiDeviceTest<TypeParam> {
|
||||
protected:
|
||||
SolverParameter simple_solver_param() {
|
||||
const string solver_proto =
|
||||
"train_net_param { "
|
||||
" layer { "
|
||||
" name: 'data' type: 'DummyData' top: 'data' "
|
||||
" dummy_data_param { shape { dim: 1 } } "
|
||||
" } "
|
||||
"} ";
|
||||
SolverParameter solver_param;
|
||||
CHECK(google::protobuf::TextFormat::ParseFromString(
|
||||
solver_proto, &solver_param));
|
||||
return solver_param;
|
||||
}
|
||||
};
|
||||
|
||||
TYPED_TEST_CASE(SolverFactoryTest, TestDtypesAndDevices);
|
||||
|
||||
TYPED_TEST(SolverFactoryTest, TestCreateSolver) {
|
||||
typedef typename TypeParam::Dtype Dtype;
|
||||
typename SolverRegistry<Dtype>::CreatorRegistry& registry =
|
||||
SolverRegistry<Dtype>::Registry();
|
||||
shared_ptr<Solver<Dtype> > solver;
|
||||
SolverParameter solver_param = this->simple_solver_param();
|
||||
for (typename SolverRegistry<Dtype>::CreatorRegistry::iterator iter =
|
||||
registry.begin(); iter != registry.end(); ++iter) {
|
||||
solver_param.set_type(iter->first);
|
||||
solver.reset(SolverRegistry<Dtype>::CreateSolver(solver_param));
|
||||
EXPECT_EQ(iter->first, solver->type());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace caffe
|
|
@ -194,7 +194,7 @@ int train() {
|
|||
GetRequestedAction(FLAGS_sighup_effect));
|
||||
|
||||
shared_ptr<caffe::Solver<float> >
|
||||
solver(caffe::GetSolver<float>(solver_param));
|
||||
solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));
|
||||
|
||||
solver->SetActionFunction(signal_handler.GetActionFunction());
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче