When nbruttsineachrecurrentite is larger than 1, learnratePerSample is normalized
This commit is contained in:
Родитель
1a6b330048
Коммит
b8dd80c67c
|
@ -122,6 +122,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
SGD(const ConfigParameters& configSGD)
|
||||
{
|
||||
ConfigArray learningRatesPerMBStr = configSGD("learningRatesPerMB", "");
|
||||
m_needToNormalizeLRByParallUtterance = false;
|
||||
floatargvector learningRatesPerMB = learningRatesPerMBStr;
|
||||
|
||||
ConfigArray learningRatesPerSampleStr = configSGD("learningRatesPerSample", "");
|
||||
|
@ -302,6 +303,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
{
|
||||
m_learningRatesPerSample[i] = learningRatesPerMB[i]/m_mbSize[i];
|
||||
}
|
||||
m_needToNormalizeLRByParallUtterance = true;
|
||||
}
|
||||
m_momentumPerMB = 0.9f;
|
||||
if (momentumPerMB.size() >0)
|
||||
|
@ -525,6 +527,14 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
if (0 == myRank) // only needs to be done by one process
|
||||
net.SaveToFile(GetModelNameForEpoch(int(startEpoch) - 1));
|
||||
|
||||
// first, we need to normalize the effect of nbruttsineachrecurrentiter
|
||||
if (trainSetDataReader->NumberSlicesInEachRecurrentIter()>1 && m_needToNormalizeLRByParallUtterance)
|
||||
{
|
||||
for (auto & x : m_learningRatesPerSample)
|
||||
{
|
||||
x /= trainSetDataReader->NumberSlicesInEachRecurrentIter();
|
||||
}
|
||||
}
|
||||
bool learnRateInitialized = false;
|
||||
if (startEpoch > 0)
|
||||
{
|
||||
|
@ -571,6 +581,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
SetDropoutRate(net, criterionNodes[0], m_dropoutRates[i], prevDropoutRate, dropOutSeed);
|
||||
|
||||
//learning rate adjustment
|
||||
|
||||
if (m_autoLearnRateSearchType == LearningRateSearchAlgorithm::None || (m_learningRatesPerSample.size() > 0 && m_learningRatesPerSample.size() > i))
|
||||
{
|
||||
learnRatePerSample = m_learningRatesPerSample[i];
|
||||
|
@ -1515,6 +1526,7 @@ protected:
|
|||
protected:
|
||||
|
||||
floatargvector m_learningRatesPerSample; /// learning rate per sample provided outside
|
||||
bool m_needToNormalizeLRByParallUtterance; // only true when the user specify LearningRatePerMB and the number of parallel utterances in Reader > 1
|
||||
intargvector m_mbSize;
|
||||
size_t m_epochSize;
|
||||
size_t m_maxEpochs;
|
||||
|
|
Загрузка…
Ссылка в новой задаче