Adding nnet-modify-learning-rates.cc

git-svn-id: https://svn.code.sf.net/p/kaldi/code/sandbox/dan2@3239 5e6a8d80-dfce-4ca6-a32a-6e07a63d50c8
This commit is contained in:
Guoguo Chen 2013-11-29 06:40:40 +00:00
Parent 1d3cc77dbb
Commit 506c9faf2c
2 changed files: 173 additions, 1 deletion

nnet2bin/Makefile

@@ -22,7 +22,8 @@ BINFILES = nnet-randomize-frames nnet-am-info nnet-init \
 	nnet-get-feature-transform-multi nnet-copy-egs-discriminative \
 	nnet-get-egs-discriminative nnet-shuffle-egs-discriminative \
 	nnet-compare-hash-discriminative nnet-combine-egs-discriminative \
-	nnet-train-discriminative-simple nnet-train-discriminative-parallel
+	nnet-train-discriminative-simple nnet-train-discriminative-parallel \
+	nnet-modify-learning-rates

 OBJFILES =

nnet2bin/nnet-modify-learning-rates.cc

@@ -0,0 +1,171 @@
// nnet2bin/nnet-modify-learning-rates.cc

// Copyright 2013  Guoguo Chen

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS
// OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY
// IMPLIED WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR
// PURPOSE, MERCHANTABILITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "hmm/transition-model.h"
#include "nnet2/nnet-randomize.h"
#include "nnet2/train-nnet.h"
#include "nnet2/am-nnet.h"
int main(int argc, char *argv[]) {
try {
using namespace kaldi;
using namespace kaldi::nnet2;
typedef kaldi::int32 int32;
typedef kaldi::int64 int64;
    const char *usage =
        "This program modifies the learning rates so as to equalize the\n"
        "relative changes in parameters for each layer, while keeping their\n"
        "geometric mean the same (or changing it to a value specified using\n"
        "the --average-learning-rate option).\n"
        "\n"
        "Usage: nnet-modify-learning-rates [options] <prev-model> \\\n"
        "                                  <cur-model> <modified-cur-model>\n"
        "e.g.: nnet-modify-learning-rates --average-learning-rate=0.0002 \\\n"
        "      5.mdl 6.mdl 6.mdl\n";
    bool binary_write = true;
    BaseFloat average_learning_rate = 0.0;
    std::string use_gpu = "optional";

    ParseOptions po(usage);
    po.Register("binary", &binary_write, "Write output in binary mode");
    po.Register("average-learning-rate", &average_learning_rate,
                "If supplied, change the geometric mean of the learning "
                "rates to the given value.");
    po.Register("use-gpu", &use_gpu,
                "yes|no|optional, only has effect if compiled with CUDA");

    po.Read(argc, argv);

    if (po.NumArgs() != 3) {
      po.PrintUsage();
      exit(1);
    }
    KALDI_ASSERT(average_learning_rate >= 0);

#if HAVE_CUDA==1
    CuDevice::Instantiate().SelectGpuId(use_gpu);
#endif

    // The third argument is an output, hence "wxfilename".
    std::string prev_nnet_rxfilename = po.GetArg(1),
        cur_nnet_rxfilename = po.GetArg(2),
        modified_cur_nnet_wxfilename = po.GetArg(3);
    TransitionModel trans_model;
    AmNnet am_prev_nnet, am_cur_nnet;
    {
      bool binary_read;
      Input ki(prev_nnet_rxfilename, &binary_read);
      trans_model.Read(ki.Stream(), binary_read);
      am_prev_nnet.Read(ki.Stream(), binary_read);
    }
    {
      bool binary_read;
      Input ki(cur_nnet_rxfilename, &binary_read);
      trans_model.Read(ki.Stream(), binary_read);
      am_cur_nnet.Read(ki.Stream(), binary_read);
    }

    if (am_prev_nnet.GetNnet().GetParameterDim() !=
        am_cur_nnet.GetNnet().GetParameterDim()) {
      KALDI_WARN << "Parameter-dim mismatch, cannot equalize the relative "
                 << "changes in parameters for each layer.";
      exit(0);
    }
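
    // Concretely, the steps below are:
    //   1. compute the relative parameter change of each updatable component
    //      between <prev-model> and <cur-model>;
    //   2. divide each learning rate in <cur-model> by that relative change;
    //   3. rescale the resulting learning rates so that their geometric mean
    //      equals the target (by default, the old geometric mean).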
    // Gets relative parameter differences.
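    // For each updatable component i, the quantity computed here is
    // ||prev_i - cur_i|| / ||prev_i||, where ||.|| is the Euclidean norm over
    // that component's parameters: ComponentDotProducts() fills in the
    // per-component dot products (squared norms, when a net is dotted with
    // itself), and ApplyPow(0.5) turns those into norms.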
    int32 num_updatable = am_prev_nnet.GetNnet().NumUpdatableComponents();
    Vector<BaseFloat> relative_diff(num_updatable);
    {
      Nnet diff_nnet(am_prev_nnet.GetNnet());
      diff_nnet.AddNnet(-1.0, am_cur_nnet.GetNnet());
      diff_nnet.ComponentDotProducts(diff_nnet, &relative_diff);
      relative_diff.ApplyPow(0.5);
      Vector<BaseFloat> baseline_prod(num_updatable);
      am_prev_nnet.GetNnet().ComponentDotProducts(am_prev_nnet.GetNnet(),
                                                  &baseline_prod);
      baseline_prod.ApplyPow(0.5);
      relative_diff.DivElements(baseline_prod);
      KALDI_LOG << "Relative parameter differences per layer are "
                << relative_diff;
      // If the relative parameter difference for a certain layer is zero,
      // set it to the mean of the remaining values.
      int32 num_zero = 0;
      for (int32 i = 0; i < num_updatable; i++) {
        if (relative_diff(i) == 0.0) {
          num_zero++;
        }
      }
      if (num_zero > 0) {
        // If every entry were zero, the division below would be 0/0.
        KALDI_ASSERT(num_zero < num_updatable);
        BaseFloat average_diff = relative_diff.Sum()
            / static_cast<BaseFloat>(num_updatable - num_zero);
        for (int32 i = 0; i < num_updatable; i++) {
          if (relative_diff(i) == 0.0) {
            relative_diff(i) = average_diff;
          }
        }
        KALDI_LOG << "Zeros detected in the relative parameter difference "
                  << "vector, updating the vector to " << relative_diff;
      }
    }
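
    // Example (hypothetical numbers): with relative_diff = [0.02, 0.005, 0]
    // over three updatable layers, the zero entry would be replaced by
    // (0.02 + 0.005) / 2 = 0.0125, giving [0.02, 0.005, 0.0125].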
    // Gets learning rates for the current neural net.
    Vector<BaseFloat> cur_nnet_learning_rates(num_updatable);
    am_cur_nnet.GetNnet().GetLearningRates(&cur_nnet_learning_rates);
    KALDI_LOG << "Old learning rates for current model per layer are "
              << cur_nnet_learning_rates;
    // Gets the target geometric mean.
    BaseFloat target_geo_mean = 0.0;
    if (average_learning_rate == 0.0) {
      // By default, keep the geometric mean of the current learning rates;
      // exp() of the mean of the logs is the geometric mean.
      target_geo_mean = exp(cur_nnet_learning_rates.SumLog()
                            / static_cast<BaseFloat>(num_updatable));
    } else {
      target_geo_mean = average_learning_rate;
    }
    KALDI_ASSERT(target_geo_mean > 0.0);
    // Works out the new learning rates: divide by the relative differences,
    // then rescale so that the geometric mean matches the target.
    cur_nnet_learning_rates.DivElements(relative_diff);
    BaseFloat cur_geo_mean = exp(cur_nnet_learning_rates.SumLog()
                                 / static_cast<BaseFloat>(num_updatable));
    cur_nnet_learning_rates.Scale(target_geo_mean / cur_geo_mean);
    KALDI_LOG << "New learning rates for current model per layer are "
              << cur_nnet_learning_rates;
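
    // Example (hypothetical numbers): with learning rates [0.001, 0.001] and
    // relative_diff = [0.02, 0.005], division gives [0.05, 0.2], whose
    // geometric mean is 0.1; the default target is the old geometric mean,
    // 0.001, so scaling by 0.001 / 0.1 yields [0.0005, 0.002]. The layer
    // that moved more is slowed down, the other sped up, and the geometric
    // mean is unchanged.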
    // Sets the learning rates and writes the updated model.
    am_cur_nnet.GetNnet().SetLearningRates(cur_nnet_learning_rates);

    Output ko(modified_cur_nnet_wxfilename, binary_write);
    trans_model.Write(ko.Stream(), binary_write);
    am_cur_nnet.Write(ko.Stream(), binary_write);

    return 0;
  } catch(const std::exception &e) {
    std::cerr << e.what() << '\n';
    return -1;
  }
}