Fix error when setting perMB momentum.

This commit is contained in:
Yu 2015-07-17 17:48:03 -04:00
Родитель 4827a728e9
Коммит b5b1c13d68
1 изменённых файлов: 20 добавлений и 1 удалений

Просмотреть файл

@ -143,6 +143,7 @@ public:
{
ConfigArray learningRatesPerMBStr = configSGD("learningRatesPerMB", "");
m_needToNormalizeLRByParallUtterance = false;
m_needToNormalizeMomentumByParallUtterance = false;
floatargvector learningRatesPerMB = learningRatesPerMBStr;
ConfigArray learningRatesPerSampleStr = configSGD("learningRatesPerSample", "");
@ -437,6 +438,8 @@ public:
}
m_momentumPerSample[i] = (float)pow(momentumPerMB[i], 1.0 / m_mbSize[i]);
}
m_needToNormalizeMomentumByParallUtterance = true;
}
else
{
@ -770,6 +773,15 @@ protected:
x /= trainSetDataReader->NumberSlicesInEachRecurrentIter();
}
}
// first, we need to normalize the effect of nbruttsineachrecurrentiter for momentum
if (trainSetDataReader->NumberSlicesInEachRecurrentIter() > 1 && m_needToNormalizeMomentumByParallUtterance)
{
for (auto& x : m_momentumPerSample)
{
x = (float)pow(x, 1.0 / trainSetDataReader->NumberSlicesInEachRecurrentIter());
}
}
bool learnRateInitialized = false;
if (startEpoch > 0)
@ -857,6 +869,7 @@ protected:
INT32 mySamples = (INT32)
#endif
size_t chosenMinibatchSize;
size_t actualMinibatchSize;
// Through the command line or config file the user can set minibatch sizes on a per epoch
// basis for a set number of epochs. For epochs after that point, m_mbSize.size(), either
@ -884,10 +897,15 @@ protected:
{
// use the explicitly set minibatch size
chosenMinibatchSize = m_mbSize[i];
if (trainSetDataReader->NumberSlicesInEachRecurrentIter() > 1 && m_needToNormalizeMomentumByParallUtterance)
{
actualMinibatchSize = chosenMinibatchSize * trainSetDataReader->NumberSlicesInEachRecurrentIter();
}
}
fprintf(stderr, "Starting Epoch %d: learning rate per sample = %f momentum = %f \n",
i + 1, learnRatePerSample, MomentumPerMB(m_momentumPerSample[i], chosenMinibatchSize));
i + 1, learnRatePerSample, MomentumPerMB(m_momentumPerSample[i], actualMinibatchSize));
TrainOneEpoch(net,
refNet,
@ -2310,6 +2328,7 @@ protected:
// only true when the user specifies learningRatesPerMB and the number of parallel utterances in the reader is > 1
bool m_needToNormalizeLRByParallUtterance;
bool m_needToNormalizeMomentumByParallUtterance;
intargvector m_mbSize;