Change the default learnRateAdjustInterval to 1 instead of 2.

This commit is contained in:
Yu 2015-02-05 15:28:37 -05:00
Parent e9b031d6ea
Commit 2550040c2d
1 changed file with 3 additions and 4 deletions


@@ -255,7 +255,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
 m_learnRateIncreaseFactor=learnRateIncreaseFactor;
 m_reduceLearnRateIfImproveLessThan=reduceLearnRateIfImproveLessThan;
 m_continueReduce=continueReduce;
-m_learnRateAdjustInterval = max((size_t) 2, learnRateAdjustInterval); //minimum interval is 1 epoch
+m_learnRateAdjustInterval = max((size_t) 1, learnRateAdjustInterval); //minimum interval is 1 epoch
 m_learnRateDecreaseFactor=learnRateDecreaseFactor;
 m_clippingThresholdPerSample=abs(clippingThresholdPerSample);
 m_numMiniBatch4LRSearch=numMiniBatch4LRSearch;
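
For context, the change above only relaxes the lower bound: before this commit any configured learnRateAdjustInterval below 2 was silently raised to 2, so the inline comment ("minimum interval is 1 epoch") did not match the code. A minimal stand-alone sketch of the new clamping behavior, using hypothetical names rather than the actual CNTK SGD class:

#include <algorithm>
#include <cstddef>
#include <cstdio>

// Hypothetical illustration, not CNTK code: clamp the requested interval
// to the new minimum of 1 epoch (previously the floor was 2).
static size_t ClampAdjustInterval(size_t requested)
{
    return std::max((size_t)1, requested);
}

int main()
{
    const size_t requests[] = {0, 1, 2, 5};
    for (size_t requested : requests)
        std::printf("requested=%zu -> effective=%zu\n",
                    requested, ClampAdjustInterval(requested));
    return 0;
}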
@@ -1441,12 +1441,11 @@ protected:
 icol = max(0, icol);
 fprintf(stderr, "\n###### d%ls######\n", node->NodeName().c_str());
-// node->FunctionValues().Print();
+//node->FunctionValues().Print();
 ElemType eOrg = node->FunctionValues()(irow,icol);
 node->UpdateEvalTimeStamp();
 net.ComputeGradient(criterionNodes[npos]); //use only the first criterion. Is
-//ElemType mbEvalCri =
 criterionNodes[npos]->FunctionValues().Get00Element(); //criterionNode should be a scalar
 ElemType eGradErr = node->GradientValues()(irow, icol);
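
The hunk above sits inside CNTK's gradient checker: it records one element eOrg of a node's FunctionValues, re-evaluates the criterion via ComputeGradient, and reads the backprop gradient eGradErr from GradientValues. The perturbation loop that produces the numeric estimate is mostly outside the lines shown, so here is a generic, hypothetical sketch of the central-difference estimate that yields the eGradNum compared below:

#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical helper, not CNTK API: central-difference estimate of
// d(criterion)/d(params[i]), mirroring what the checker does per element.
double NumericGradient(const std::function<double(const std::vector<double>&)>& evalCriterion,
                       std::vector<double> params, std::size_t i, double eps = 1e-4)
{
    const double org = params[i];            // like eOrg in the hunk above
    params[i] = org + eps;
    const double up = evalCriterion(params); // re-evaluate the criterion
    params[i] = org - eps;
    const double down = evalCriterion(params);
    params[i] = org;                         // restore the original value
    return (up - down) / (2.0 * eps);        // -> eGradNum
}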
@@ -1473,7 +1472,7 @@ protected:
 bool wrong = (std::isnan(diff) || diff > threshold);
 if (wrong)
 {
-fprintf (stderr, "\nd%ls Numeric gradient = %e, Error BP gradient = %e\n", node->NodeName().c_str(), eGradNum, eGradErr);
+fprintf (stderr, "\nd%ls Numeric gradient = %e, Error BP gradient = %e \n", node->NodeName().c_str(), eGradNum, eGradErr);
 return false;
 }
 }
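
This last hunk is the checker's pass/fail test: it compares the numeric and backprop gradients and reports a mismatch (or a NaN) before bailing out. A hedged sketch of that comparison, assuming a relative-difference formula for diff, since the line that actually computes it is not part of this hunk:

#include <algorithm>
#include <cmath>
#include <cstdio>

// Sketch only: the relative-difference formula is an assumption, not the
// exact expression used in the checker above.
bool GradientMatches(double eGradNum, double eGradErr, double threshold = 1e-3)
{
    const double diff = std::fabs(eGradNum - eGradErr) /
                        std::max(1.0, std::fabs(eGradNum) + std::fabs(eGradErr));
    const bool wrong = std::isnan(diff) || diff > threshold;
    if (wrong)
        std::fprintf(stderr, "\nNumeric gradient = %e, Error BP gradient = %e \n",
                     eGradNum, eGradErr);
    return !wrong;
}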