From dea7df7e2be311dec5b34ffc548bc374818d0815 Mon Sep 17 00:00:00 2001 From: Yuchen Fan Date: Mon, 23 Nov 2015 15:05:43 +0800 Subject: [PATCH] Fix bugs when loadBestModel in AdjustAfterEpoch --- MachineLearning/CNTKSGDLib/SGD.cpp | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/MachineLearning/CNTKSGDLib/SGD.cpp b/MachineLearning/CNTKSGDLib/SGD.cpp index bce25dd0f..6a6b1c80f 100644 --- a/MachineLearning/CNTKSGDLib/SGD.cpp +++ b/MachineLearning/CNTKSGDLib/SGD.cpp @@ -901,10 +901,10 @@ namespace Microsoft { namespace MSR { namespace CNTK { if (m_loadBestModel) { fprintf(stderr, "Loaded the previous model which has better training criterion.\n"); - net->LoadPersistableParametersFromFile(GetModelNameForEpoch(i - 1), + net->LoadPersistableParametersFromFile(GetModelNameForEpoch(i - m_learnRateAdjustInterval), m_validateAfterModelReloading); net->ResetEvalTimeStamp(); - LoadCheckPointInfo(i - 1, + LoadCheckPointInfo(i - m_learnRateAdjustInterval, /*out*/ totalSamplesSeen, /*out*/ learnRatePerSample, smoothedGradients, @@ -984,7 +984,21 @@ namespace Microsoft { namespace MSR { namespace CNTK { if (!m_keepCheckPointFiles) { // delete previous checkpoint file to save space - _wunlink(GetCheckPointFileNameForEpoch(i - 1).c_str()); + if (m_autoLearnRateSearchType == LearningRateSearchAlgorithm::AdjustAfterEpoch && m_loadBestModel) + { + if (epochsSinceLastLearnRateAdjust != 1) + { + _wunlink(GetCheckPointFileNameForEpoch(i - 1).c_str()); + } + if (epochsSinceLastLearnRateAdjust == m_learnRateAdjustInterval) + { + _wunlink(GetCheckPointFileNameForEpoch(i - m_learnRateAdjustInterval).c_str()); + } + } + else + { + _wunlink(GetCheckPointFileNameForEpoch(i - 1).c_str()); + } } }