Mirror of https://github.com/mozilla/kaldi.git

Adding nnet-modify-learning-rates.cc

git-svn-id: https://svn.code.sf.net/p/kaldi/code/sandbox/dan2@3239 5e6a8d80-dfce-4ca6-a32a-6e07a63d50c8

Parent: 1d3cc77dbb
Commit: 506c9faf2c
nnet2bin/Makefile:

@@ -22,7 +22,8 @@ BINFILES = nnet-randomize-frames nnet-am-info nnet-init \
    nnet-get-feature-transform-multi nnet-copy-egs-discriminative \
    nnet-get-egs-discriminative nnet-shuffle-egs-discriminative \
    nnet-compare-hash-discriminative nnet-combine-egs-discriminative \
-   nnet-train-discriminative-simple nnet-train-discriminative-parallel
+   nnet-train-discriminative-simple nnet-train-discriminative-parallel \
+   nnet-modify-learning-rates
 
 OBJFILES =
nnet2bin/nnet-modify-learning-rates.cc (new file):

@@ -0,0 +1,171 @@
// nnet2bin/nnet-modify-learning-rates.cc

// Copyright 2013 Guoguo Chen

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "hmm/transition-model.h"
#include "nnet2/nnet-randomize.h"
#include "nnet2/train-nnet.h"
#include "nnet2/am-nnet.h"

int main(int argc, char *argv[]) {
  try {
    using namespace kaldi;
    using namespace kaldi::nnet2;
    typedef kaldi::int32 int32;
    typedef kaldi::int64 int64;

    const char *usage =
        "This program modifies the learning rates so as to equalize the\n"
        "relative changes in parameters for each layer, while keeping their\n"
        "geometric mean the same (or changing it to a value specified using\n"
        "the --average-learning-rate option).\n"
        "\n"
        "Usage: nnet-modify-learning-rates [options] <prev-model> \\\n"
        "       <cur-model> <modified-cur-model>\n"
        "e.g.: nnet-modify-learning-rates --average-learning-rate=0.0002 \\\n"
        "      5.mdl 6.mdl 6.mdl\n";

    bool binary_write = true;
    BaseFloat average_learning_rate = 0.0;
    std::string use_gpu = "optional";

    ParseOptions po(usage);
    po.Register("binary", &binary_write, "Write output in binary mode");
    po.Register("average-learning-rate", &average_learning_rate,
                "If supplied, change learning rate geometric mean to the given "
                "value.");
    po.Register("use-gpu", &use_gpu,
                "yes|no|optional, only has effect if compiled with CUDA");

    po.Read(argc, argv);

    if (po.NumArgs() != 3) {
      po.PrintUsage();
      exit(1);
    }

    KALDI_ASSERT(average_learning_rate >= 0);

#if HAVE_CUDA==1
    CuDevice::Instantiate().SelectGpuId(use_gpu);
#endif

    std::string prev_nnet_rxfilename = po.GetArg(1),
        cur_nnet_rxfilename = po.GetArg(2),
        modified_cur_nnet_wxfilename = po.GetArg(3);

    TransitionModel trans_model;
    AmNnet am_prev_nnet, am_cur_nnet;
    {
      bool binary_read;
      Input ki(prev_nnet_rxfilename, &binary_read);
      trans_model.Read(ki.Stream(), binary_read);
      am_prev_nnet.Read(ki.Stream(), binary_read);
    }
    {
      bool binary_read;
      Input ki(cur_nnet_rxfilename, &binary_read);
      trans_model.Read(ki.Stream(), binary_read);
      am_cur_nnet.Read(ki.Stream(), binary_read);
    }

    if (am_prev_nnet.GetNnet().GetParameterDim() !=
        am_cur_nnet.GetNnet().GetParameterDim()) {
      KALDI_WARN << "Parameter-dim mismatch, cannot equalize the relative "
                 << "changes in parameters for each layer.";
      exit(0);
    }

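    // The remainder of main() implements the update described in the usage
    // message: for each updatable component i, measure the relative change
    //   d_i = ||theta_prev_i - theta_cur_i|| / ||theta_prev_i||,
    // divide that component's learning rate by d_i, and rescale all the
    // rates so that their geometric mean hits the target.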
    int32 ret = 0;

    // Gets relative parameter differences.
    int32 num_updatable = am_prev_nnet.GetNnet().NumUpdatableComponents();
    Vector<BaseFloat> relative_diff(num_updatable);
    {
      Nnet diff_nnet(am_prev_nnet.GetNnet());
      diff_nnet.AddNnet(-1.0, am_cur_nnet.GetNnet());
      diff_nnet.ComponentDotProducts(diff_nnet, &relative_diff);
      relative_diff.ApplyPow(0.5);
      Vector<BaseFloat> baseline_prod(num_updatable);
      am_prev_nnet.GetNnet().ComponentDotProducts(am_prev_nnet.GetNnet(),
                                                  &baseline_prod);
      baseline_prod.ApplyPow(0.5);
      relative_diff.DivElements(baseline_prod);
      KALDI_LOG << "Relative parameter differences per layer are "
                << relative_diff;

      // If the relative parameter difference for a certain layer is zero,
      // set it to the mean of the remaining values.
      int32 num_zero = 0;
      for (int32 i = 0; i < num_updatable; i++) {
        if (relative_diff(i) == 0.0) {
          num_zero++;
        }
      }
      if (num_zero > 0) {
        BaseFloat average_diff = relative_diff.Sum()
            / static_cast<BaseFloat>(num_updatable - num_zero);
        for (int32 i = 0; i < num_updatable; i++) {
          if (relative_diff(i) == 0.0) {
            relative_diff(i) = average_diff;
          }
        }
        KALDI_LOG << "Zeros detected in the relative parameter difference "
                  << "vector, updating the vector to " << relative_diff;
      }
    }
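    // The zero-replacement above also guards the division below: a zero d_i
    // can only occur if component i's parameters are identical in the two
    // models, and dividing a learning rate by it would be undefined.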

    // Gets learning rates for current neural net.
    Vector<BaseFloat> cur_nnet_learning_rates(num_updatable);
    am_cur_nnet.GetNnet().GetLearningRates(&cur_nnet_learning_rates);
    KALDI_LOG << "Old learning rates for current model per layer are "
              << cur_nnet_learning_rates;

    // Gets target geometric mean.
    BaseFloat target_geo_mean = 0.0;
    if (average_learning_rate == 0.0) {
      target_geo_mean = exp(cur_nnet_learning_rates.SumLog()
                            / static_cast<BaseFloat>(num_updatable));
    } else {
      target_geo_mean = average_learning_rate;
    }
    KALDI_ASSERT(target_geo_mean > 0.0);

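    // Scaling the whole vector by a constant c multiplies its geometric mean
    // by exactly c, so the single global rescale below restores the target
    // mean without disturbing the equalized ratios between layers.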
    // Works out the new learning rates.
    cur_nnet_learning_rates.DivElements(relative_diff);
    BaseFloat cur_geo_mean = exp(cur_nnet_learning_rates.SumLog()
                                 / static_cast<BaseFloat>(num_updatable));
    cur_nnet_learning_rates.Scale(target_geo_mean / cur_geo_mean);
    KALDI_LOG << "New learning rates for current model per layer are "
              << cur_nnet_learning_rates;

    // Sets learning rates and writes updated model.
    am_cur_nnet.GetNnet().SetLearningRates(cur_nnet_learning_rates);
    Output ko(modified_cur_nnet_wxfilename, binary_write);
    trans_model.Write(ko.Stream(), binary_write);
    am_cur_nnet.Write(ko.Stream(), binary_write);

    return ret;
  } catch(const std::exception &e) {
    std::cerr << e.what() << '\n';
    return -1;
  }
}
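
To make the arithmetic concrete, below is a minimal standalone sketch of the update this binary performs, written against plain std::vector rather than Kaldi's Vector and Nnet classes; the layer count and sample values are illustrative only, not taken from any real model.

// Standalone sketch (not Kaldi code; names and values are illustrative).
// Given per-layer relative parameter changes d_i and current learning rates
// lr_i, compute lr_i' = lr_i / d_i, then rescale so the geometric mean of
// the new rates equals a target (here, the old geometric mean).
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  std::vector<double> relative_diff = {0.02, 0.01, 0.04};  // d_i per layer
  std::vector<double> lr = {0.001, 0.001, 0.001};          // current rates
  const int n = static_cast<int>(lr.size());

  // Target: keep the current geometric mean (the tool's default when
  // --average-learning-rate is not given).
  double log_sum = 0.0;
  for (double r : lr) log_sum += std::log(r);
  const double target_geo_mean = std::exp(log_sum / n);

  // Equalize: layers that moved more get proportionally smaller rates.
  for (int i = 0; i < n; i++) lr[i] /= relative_diff[i];

  // Rescale: multiplying every entry by c multiplies the geometric mean by
  // exactly c, so one global scale restores the target.
  log_sum = 0.0;
  for (double r : lr) log_sum += std::log(r);
  const double cur_geo_mean = std::exp(log_sum / n);
  for (int i = 0; i < n; i++) lr[i] *= target_geo_mean / cur_geo_mean;

  for (int i = 0; i < n; i++) std::printf("layer %d: lr = %g\n", i, lr[i]);
  return 0;
}

With these inputs the layer whose parameters moved twice as fast ends up with half the rate: the printed values are 0.001, 0.002 and 0.0005, whose geometric mean is again 0.001.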