When nbruttsineachrecurrentite is larger than 1, learnratePerSample is normalized

This commit is contained in:
erw 2015-01-21 13:57:24 -08:00
Родитель 1a6b330048
Коммит b8dd80c67c
1 изменённых файлов: 12 добавлений и 0 удалений

Просмотреть файл

@ -122,6 +122,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
SGD(const ConfigParameters& configSGD)
{
ConfigArray learningRatesPerMBStr = configSGD("learningRatesPerMB", "");
m_needToNormalizeLRByParallUtterance = false;
floatargvector learningRatesPerMB = learningRatesPerMBStr;
ConfigArray learningRatesPerSampleStr = configSGD("learningRatesPerSample", "");
@ -302,6 +303,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
{
m_learningRatesPerSample[i] = learningRatesPerMB[i]/m_mbSize[i];
}
m_needToNormalizeLRByParallUtterance = true;
}
m_momentumPerMB = 0.9f;
if (momentumPerMB.size() >0)
@ -525,6 +527,14 @@ namespace Microsoft { namespace MSR { namespace CNTK {
if (0 == myRank) // only needs to be done by one process
net.SaveToFile(GetModelNameForEpoch(int(startEpoch) - 1));
// first, we need to normalize the effect of nbruttsineachrecurrentiter
if (trainSetDataReader->NumberSlicesInEachRecurrentIter()>1 && m_needToNormalizeLRByParallUtterance)
{
for (auto & x : m_learningRatesPerSample)
{
x /= trainSetDataReader->NumberSlicesInEachRecurrentIter();
}
}
bool learnRateInitialized = false;
if (startEpoch > 0)
{
@ -571,6 +581,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
SetDropoutRate(net, criterionNodes[0], m_dropoutRates[i], prevDropoutRate, dropOutSeed);
//learning rate adjustment
if (m_autoLearnRateSearchType == LearningRateSearchAlgorithm::None || (m_learningRatesPerSample.size() > 0 && m_learningRatesPerSample.size() > i))
{
learnRatePerSample = m_learningRatesPerSample[i];
@ -1515,6 +1526,7 @@ protected:
protected:
floatargvector m_learningRatesPerSample; /// learning rate per sample provided outside
bool m_needToNormalizeLRByParallUtterance; // only true when the user specify LearningRatePerMB and the number of parallel utterances in Reader > 1
intargvector m_mbSize;
size_t m_epochSize;
size_t m_maxEpochs;