From 506c9faf2c7a85a03abe36952d984d106cb7303c Mon Sep 17 00:00:00 2001 From: Guoguo Chen Date: Fri, 29 Nov 2013 06:40:40 +0000 Subject: [PATCH] Adding nnet-modify-learning-rates.cc git-svn-id: https://svn.code.sf.net/p/kaldi/code/sandbox/dan2@3239 5e6a8d80-dfce-4ca6-a32a-6e07a63d50c8 --- src/nnet2bin/Makefile | 3 +- src/nnet2bin/nnet-modify-learning-rates.cc | 171 +++++++++++++++++++++ 2 files changed, 173 insertions(+), 1 deletion(-) create mode 100644 src/nnet2bin/nnet-modify-learning-rates.cc diff --git a/src/nnet2bin/Makefile b/src/nnet2bin/Makefile index 288cdee59..48d9f869a 100644 --- a/src/nnet2bin/Makefile +++ b/src/nnet2bin/Makefile @@ -22,7 +22,8 @@ BINFILES = nnet-randomize-frames nnet-am-info nnet-init \ nnet-get-feature-transform-multi nnet-copy-egs-discriminative \ nnet-get-egs-discriminative nnet-shuffle-egs-discriminative \ nnet-compare-hash-discriminative nnet-combine-egs-discriminative \ - nnet-train-discriminative-simple nnet-train-discriminative-parallel + nnet-train-discriminative-simple nnet-train-discriminative-parallel \ + nnet-modify-learning-rates OBJFILES = diff --git a/src/nnet2bin/nnet-modify-learning-rates.cc b/src/nnet2bin/nnet-modify-learning-rates.cc new file mode 100644 index 000000000..a3805b327 --- /dev/null +++ b/src/nnet2bin/nnet-modify-learning-rates.cc @@ -0,0 +1,171 @@ +// nnet2bin/nnet-modify-learning-rates.cc + +// Copyright 2013 Guoguo Chen + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "hmm/transition-model.h" +#include "nnet2/nnet-randomize.h" +#include "nnet2/train-nnet.h" +#include "nnet2/am-nnet.h" + + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace kaldi::nnet2; + typedef kaldi::int32 int32; + typedef kaldi::int64 int64; + + const char *usage = + "This program modifies the learning rates so as to equalize the\n" + "relative changes in parameters for each layer, while keeping their\n" + "geometric mean the same (or changing it to a value specified using\n" + "the --average-learning-rate option)." + "\n" + "Usage: nnet-modify-learning-rates [options] \\\n" + " \n" + "e.g.: nnet-modify-learning-rates --average-learning-rate=0.0002 \\\n" + " 5.mdl 6.mdl 6.mdl\n"; + + bool binary_write = true; + BaseFloat average_learning_rate = 0.0; + std::string use_gpu = "optional"; + + ParseOptions po(usage); + po.Register("binary", &binary_write, "Write output in binary mode"); + po.Register("average-learning-rate", &average_learning_rate, + "If supplied, change learning rate geometric mean to the given " + "value."); + po.Register("use-gpu", &use_gpu, + "yes|no|optional, only has effect if compiled with CUDA"); + + po.Read(argc, argv); + + if (po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + + KALDI_ASSERT(average_learning_rate >= 0); + +#if HAVE_CUDA==1 + CuDevice::Instantiate().SelectGpuId(use_gpu); +#endif + + std::string prev_nnet_rxfilename = po.GetArg(1), + cur_nnet_rxfilename = po.GetArg(2), + modified_cur_nnet_rxfilename = po.GetOptArg(3); + + TransitionModel trans_model; + AmNnet am_prev_nnet, am_cur_nnet; + { + bool binary_read; + Input ki(prev_nnet_rxfilename, &binary_read); + trans_model.Read(ki.Stream(), binary_read); + am_prev_nnet.Read(ki.Stream(), binary_read); + } + { + bool binary_read; + Input ki(cur_nnet_rxfilename, &binary_read); + trans_model.Read(ki.Stream(), binary_read); + am_cur_nnet.Read(ki.Stream(), binary_read); + } + + if (am_prev_nnet.GetNnet().GetParameterDim() != + am_cur_nnet.GetNnet().GetParameterDim()) { + KALDI_WARN << "Parameter-dim mismatch, cannot equalize the relative " + << "changes in parameters for each layer."; + exit(0); + } + + int32 ret = 0; + + // Gets relative parameter differences. + int32 num_updatable = am_prev_nnet.GetNnet().NumUpdatableComponents(); + Vector relative_diff(num_updatable); + { + Nnet diff_nnet(am_prev_nnet.GetNnet()); + diff_nnet.AddNnet(-1.0, am_cur_nnet.GetNnet()); + diff_nnet.ComponentDotProducts(diff_nnet, &relative_diff); + relative_diff.ApplyPow(0.5); + Vector baseline_prod(num_updatable); + am_prev_nnet.GetNnet().ComponentDotProducts(am_prev_nnet.GetNnet(), + &baseline_prod); + baseline_prod.ApplyPow(0.5); + relative_diff.DivElements(baseline_prod); + KALDI_LOG << "Relative parameter differences per layer are " + << relative_diff; + + // If relative parameter difference for a certain is zero, set it to the + // mean of the rest values. + int32 num_zero = 0; + for (int32 i = 0; i < num_updatable; i++) { + if (relative_diff(i) == 0.0) { + num_zero++; + } + } + if (num_zero > 0) { + BaseFloat average_diff = relative_diff.Sum() + / static_cast(num_updatable - num_zero); + for (int32 i = 0; i < num_updatable; i++) { + if (relative_diff(i) == 0.0) { + relative_diff(i) = average_diff; + } + } + KALDI_LOG << "Zeros detected in the relative parameter difference " + << "vector, updating the vector to " << relative_diff; + } + } + + // Gets learning rates for current neural net. + Vector cur_nnet_learning_rates(num_updatable); + am_cur_nnet.GetNnet().GetLearningRates(&cur_nnet_learning_rates); + KALDI_LOG << "Old learning rates for current model per layers are " + << cur_nnet_learning_rates; + + // Gets target geometric mean. + BaseFloat target_geo_mean = 0.0; + if (average_learning_rate == 0.0) { + target_geo_mean = exp(cur_nnet_learning_rates.SumLog() + / static_cast(num_updatable)); + } else { + target_geo_mean = average_learning_rate; + } + KALDI_ASSERT(target_geo_mean > 0.0); + + // Works out the new learning rates. + cur_nnet_learning_rates.DivElements(relative_diff); + BaseFloat cur_geo_mean = exp(cur_nnet_learning_rates.SumLog() + / static_cast(num_updatable)); + cur_nnet_learning_rates.Scale(target_geo_mean / cur_geo_mean); + KALDI_LOG << "New learning rates for current model per layers are " + << cur_nnet_learning_rates; + + // Sets learning rates and writes updated model. + am_cur_nnet.GetNnet().SetLearningRates(cur_nnet_learning_rates); + Output ko(modified_cur_nnet_rxfilename, binary_write); + trans_model.Write(ko.Stream(), binary_write); + am_cur_nnet.Write(ko.Stream(), binary_write); + + return ret; + } catch(const std::exception &e) { + std::cerr << e.what() << '\n'; + return -1; + } +}