allow setting custom weight decay

Yangqing Jia 2013-10-22 13:18:50 -07:00
Parent 8046bcf5fd
Commit 62089dd8da
5 changed files: 50 additions and 14 deletions

View file

@@ -5,7 +5,7 @@ layers {
type: "data"
source: "/home/jiayq/caffe-train-leveldb"
meanfile: "/home/jiayq/ilsvrc2012_mean.binaryproto"
- batchsize: 64
+ batchsize: 256
cropsize: 227
mirror: true
}
@@ -25,10 +25,12 @@ layers {
}
bias_filler {
type: "constant"
- value: 0.1
+ value: 0.
}
blobs_lr: 1.
blobs_lr: 2.
+ weight_decay: 1.
+ weight_decay: 0.
}
bottom: "data"
top: "conv1"
@@ -85,10 +87,12 @@ layers {
}
bias_filler {
type: "constant"
- value: 0.1
+ value: 1.
}
blobs_lr: 1.
blobs_lr: 2.
+ weight_decay: 1.
+ weight_decay: 0.
}
bottom: "pad2"
top: "conv2"
@@ -144,10 +148,12 @@ layers {
}
bias_filler {
type: "constant"
- value: 0.1
+ value: 0.
}
blobs_lr: 1.
blobs_lr: 2.
+ weight_decay: 1.
+ weight_decay: 0.
}
bottom: "pad3"
top: "conv3"
@@ -182,10 +188,12 @@ layers {
}
bias_filler {
type: "constant"
- value: 0.1
+ value: 1.
}
blobs_lr: 1.
blobs_lr: 2.
+ weight_decay: 1.
+ weight_decay: 0.
}
bottom: "pad4"
top: "conv4"
@@ -220,10 +228,12 @@ layers {
}
bias_filler {
type: "constant"
- value: 0.1
+ value: 1.
}
blobs_lr: 1.
blobs_lr: 2.
+ weight_decay: 1.
+ weight_decay: 0.
}
bottom: "pad5"
top: "conv5"
@@ -258,10 +268,12 @@ layers {
}
bias_filler {
type: "constant"
- value: 0.1
+ value: 1.
}
blobs_lr: 1.
blobs_lr: 2.
+ weight_decay: 1.
+ weight_decay: 0.
}
bottom: "pool5"
top: "fc6"
@@ -294,10 +306,12 @@ layers {
}
bias_filler {
type: "constant"
- value: 0.1
+ value: 1.
}
blobs_lr: 1.
blobs_lr: 2.
+ weight_decay: 1.
+ weight_decay: 0.
}
bottom: "fc6"
top: "fc7"
@@ -334,6 +348,8 @@ layers {
}
blobs_lr: 1.
blobs_lr: 2.
+ weight_decay: 1.
+ weight_decay: 0.
}
bottom: "fc7"
top: "fc8"

View file

@@ -60,6 +60,7 @@ class Net {
inline vector<shared_ptr<Blob<Dtype> > >& params() { return params_; }
// returns the parameter learning rate multipliers
inline vector<float>& params_lr() {return params_lr_; }
+ inline vector<float>& params_weight_decay() { return params_weight_decay_; }
// Updates the network
void Update();
@@ -86,6 +87,8 @@ class Net {
vector<shared_ptr<Blob<Dtype> > > params_;
// the learning rate multipliers
vector<float> params_lr_;
+ // the weight decay multipliers
+ vector<float> params_weight_decay_;
DISABLE_COPY_AND_ASSIGN(Net);
};
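
The new accessor mirrors params() and params_lr(): the three vectors are index-aligned, one entry per learnable blob across the whole net. A rough usage sketch (the helper itself is hypothetical, not part of this commit, and assumes the usual Caffe using-declarations for vector, shared_ptr and Blob are in scope):

// Hypothetical helper: log each learnable blob's size together with its
// per-blob learning-rate and weight-decay multipliers.
template <typename Dtype>
void LogParamMultipliers(Net<Dtype>& net) {
  vector<shared_ptr<Blob<Dtype> > >& params = net.params();
  vector<float>& lr_mult = net.params_lr();
  vector<float>& decay_mult = net.params_weight_decay();
  for (int i = 0; i < params.size(); ++i) {
    LOG(INFO) << "param blob " << i << ": count " << params[i]->count()
              << ", lr mult " << lr_mult[i]
              << ", weight decay mult " << decay_mult[i];
  }
}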

View file

@@ -119,6 +119,20 @@ Net<Dtype>::Net(const NetParameter& param,
params_lr_.push_back(1.);
}
}
+ // push the weight decay multipliers
+ if (layers_[i]->layer_param().weight_decay_size()) {
+ CHECK_EQ(layers_[i]->layer_param().weight_decay_size(),
+ layer_blobs.size());
+ for (int j = 0; j < layer_blobs.size(); ++j) {
+ float local_decay = layers_[i]->layer_param().weight_decay(j);
+ CHECK_GT(local_decay, 0.);
+ params_weight_decay_.push_back(local_decay);
+ }
+ } else {
+ for (int j = 0; j < layer_blobs.size(); ++j) {
+ params_weight_decay_.push_back(1.);
+ }
+ }
for (int topid = 0; topid < top_vecs_[i].size(); ++topid) {
LOG(INFO) << "Top shape: " << top_vecs_[i][topid]->channels() << " "
<< top_vecs_[i][topid]->height() << " "
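
As with blobs_lr, the new field is all-or-nothing per layer: if weight_decay is present it must supply one value per blob (the CHECK_EQ above), and if it is absent every blob in that layer falls back to a multiplier of 1, i.e. the plain global weight decay. For example, a layer written without the field (hypothetical sketch) behaves exactly as before this commit:

layers {
  layer {
    name: "some_layer"   # hypothetical layer
    # no blobs_lr / weight_decay entries: every blob in this layer
    # gets lr multiplier 1. and weight-decay multiplier 1.
  }
  bottom: "in"
  top: "out"
}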

View file

@@ -76,6 +76,8 @@ message LayerParameter {
// The ratio that is multiplied on the global learning rate. If you want to set
// the learning ratio for one blob, you need to set it for all blobs.
repeated float blobs_lr = 51;
+ // The weight decay that is multiplied on the global weight decay.
+ repeated float weight_decay = 52;
}
message LayerConnection {
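
The multiplier composes with the solver's global setting rather than replacing it: as the solver change below shows, the effective decay for a blob is the global weight_decay times this per-blob value. For example, with a hypothetical solver weight_decay of 0.0005, a multiplier of 1. yields an effective decay of 0.0005 (0.0005 × 1.), while a multiplier of 0. disables decay for that blob entirely (0.0005 × 0. = 0).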

View file

@@ -132,6 +132,7 @@ template <typename Dtype>
void SGDSolver<Dtype>::ComputeUpdateValue() {
vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
vector<float>& net_params_lr = this->net_->params_lr();
+ vector<float>& net_params_weight_decay = this->net_->params_weight_decay();
// get the learning rate
Dtype rate = GetLearningRate();
if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
@@ -139,20 +140,19 @@ void SGDSolver<Dtype>::ComputeUpdateValue() {
}
Dtype momentum = this->param_.momentum();
Dtype weight_decay = this->param_.weight_decay();
- // LOG(ERROR) << "rate:" << rate << " momentum:" << momentum
- // << " weight_decay:" << weight_decay;
switch (Caffe::mode()) {
case Caffe::CPU:
for (int param_id = 0; param_id < net_params.size(); ++param_id) {
// Compute the value to history, and then copy them to the blob's diff.
Dtype local_rate = rate * net_params_lr[param_id];
+ Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
caffe_axpby(net_params[param_id]->count(), local_rate,
net_params[param_id]->cpu_diff(), momentum,
history_[param_id]->mutable_cpu_data());
- if (weight_decay) {
+ if (local_decay) {
// add weight decay
caffe_axpy(net_params[param_id]->count(),
- weight_decay * local_rate,
+ local_decay * local_rate,
net_params[param_id]->cpu_data(),
history_[param_id]->mutable_cpu_data());
}
@@ -166,13 +166,14 @@ void SGDSolver<Dtype>::ComputeUpdateValue() {
for (int param_id = 0; param_id < net_params.size(); ++param_id) {
// Compute the value to history, and then copy them to the blob's diff.
Dtype local_rate = rate * net_params_lr[param_id];
+ Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
caffe_gpu_axpby(net_params[param_id]->count(), local_rate,
net_params[param_id]->gpu_diff(), momentum,
history_[param_id]->mutable_gpu_data());
- if (weight_decay) {
+ if (local_decay) {
// add weight decay
caffe_gpu_axpy(net_params[param_id]->count(),
- weight_decay * local_rate,
+ local_decay * local_rate,
net_params[param_id]->gpu_data(),
history_[param_id]->mutable_gpu_data());
}
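
The CPU and GPU branches implement the same rule, with the per-blob decay folded into the gradient before the momentum step. A minimal standalone sketch of what ends up in history_ for one blob (a hypothetical function, not the Caffe API; the solver then copies history into the blob's diff and Net::Update() applies it to the blob's data):

#include <vector>

// Per-blob SGD update with momentum and per-blob multipliers:
//   history = momentum * history + local_rate * (diff + local_decay * data)
// where local_rate = rate * lr_mult and local_decay = weight_decay * decay_mult.
void ComputeUpdateSketch(std::vector<float>* history,
                         const std::vector<float>& data,   // blob weights
                         const std::vector<float>& diff,   // blob gradient
                         float rate, float lr_mult,
                         float weight_decay, float decay_mult,
                         float momentum) {
  const float local_rate = rate * lr_mult;
  const float local_decay = weight_decay * decay_mult;
  for (size_t k = 0; k < history->size(); ++k) {
    (*history)[k] = momentum * (*history)[k]
        + local_rate * (diff[k] + local_decay * data[k]);
  }
}

With weight_decay: 0. on the bias blobs, local_decay is zero and the decay term vanishes for those blobs, which is exactly what the if (local_decay) guard above skips.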