Moved the parallel training guard for writing model/checkpoint files into the actual save functions instead of guarding at the call sites

This commit is contained in:
Amit Agarwal 2015-09-15 18:54:05 -07:00
Parent d204e2d11e
Commit 62558e3079
8 changed files with 58 additions and 53 deletions
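In short, the rank-0 guard now lives inside the save routines themselves. The following sketch summarizes the before/after call pattern from the diff below (simplified; the model-path variable name is illustrative):

// Before this commit, every call site had to guard the save itself:
//
//     if ((g_mpi == nullptr) || g_mpi->IsMainNode())
//         net.SaveToFile(modelPath);
//
// After this commit, call sites simply call SaveToFile() and the guard
// lives inside the save function, so non-main MPI ranks skip the write.
void ComputationNetwork::SaveToFile(const std::wstring& fileName, const FileOptions fileFormat) const
{
    if ((g_mpi == nullptr) || g_mpi->IsMainNode())
    {
        wstring tmpFileName = fileName + L".tmp";
        SaveToFileImpl(tmpFileName, fileFormat);
        renameOrDie(tmpFileName, fileName);
    }
}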

View File

@@ -247,3 +247,5 @@ namespace Microsoft { namespace MSR { namespace CNTK {
}
};
}}}
extern Microsoft::MSR::CNTK::MPIWrapper *g_mpi;

View File

@@ -50,7 +50,7 @@
#include <fileutil.h>
// TODO: Get rid of this global
Microsoft::MSR::CNTK::MPIWrapper *g_mpi;
Microsoft::MSR::CNTK::MPIWrapper *g_mpi = nullptr;
using namespace std;
using namespace Microsoft::MSR;

View File

@@ -55,11 +55,16 @@ namespace Microsoft { namespace MSR { namespace CNTK {
void ComputationNetwork::SaveToFile(const std::wstring& fileName, const FileOptions fileFormat) const
{
// Saving into temporary file and then renaming it to the requested fileName
// This is a standard trick to avoid having corrupted model files if the process dies during writing
wstring tmpFileName = fileName + L".tmp";
SaveToFileImpl(tmpFileName, fileFormat);
renameOrDie(tmpFileName, fileName);
// In case of parallel training, only the main node should be saving the model to prevent
// the parallel training nodes from colliding when writing the same file
if ((g_mpi == nullptr) || g_mpi->IsMainNode())
{
// Saving into temporary file and then renaming it to the requested fileName
// This is a standard trick to avoid having corrupted model files if the process dies during writing
wstring tmpFileName = fileName + L".tmp";
SaveToFileImpl(tmpFileName, fileFormat);
renameOrDie(tmpFileName, fileName);
}
}
// TODO: how does the file distinguish float vs double nodes?
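As an aside, the write-to-temp-then-rename idiom used by SaveToFile here (and by the checkpoint writer further down) can be sketched on its own as follows. This is a minimal illustration, not CNTK code: the SaveAtomically helper and its error handling are hypothetical, and CNTK's own renameOrDie presumably wraps the platform rename call.

#include <cstdio>
#include <fstream>
#include <stdexcept>
#include <string>

// Minimal sketch of the atomic-save idiom: write everything to a temporary
// file, flush it, then rename it over the final path, so a crash in the
// middle of writing never leaves a truncated model or checkpoint behind.
void SaveAtomically(const std::string& fileName, const std::string& payload)
{
    const std::string tmpFileName = fileName + ".tmp";
    {
        std::ofstream out(tmpFileName, std::ios::binary);
        out << payload;
        out.flush();                                    // push data to the OS
        if (!out)
            throw std::runtime_error("failed writing " + tmpFileName);
    }
    // On POSIX file systems std::rename replaces an existing target
    // atomically; Windows needs extra handling, which is presumably what
    // CNTK's renameOrDie takes care of.
    if (std::rename(tmpFileName.c_str(), fileName.c_str()) != 0)
        throw std::runtime_error("failed renaming " + tmpFileName);
}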

View File

@@ -27,6 +27,7 @@
#include "ComputationNode.h"
#include "ScriptableObjects.h"
#include "MPIWrapper.h"
//#include "MatrixPool.h"

View File

@@ -17,6 +17,9 @@
#endif
#include "BestGpu.h"
// TODO: Get rid of this global
Microsoft::MSR::CNTK::MPIWrapper *g_mpi = nullptr;
namespace Microsoft { namespace MSR { namespace CNTK {
template<class ElemType>

View File

@@ -50,14 +50,14 @@
<PropertyGroup Label="UserMacros" />
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
<LinkIncremental>true</LinkIncremental>
<IncludePath>..\CNTKSGDLib;..\CNTKComputationNetworkLib;..\..\Math\Math;..\..\Common\Include;..\..\BrainScript;$(CUDA_PATH)\include;$(VCInstallDir)include;$(WindowsSDK_IncludePath)</IncludePath>
<LibraryPath>..\CNTKComputationNetworkLib;..\..\Math\Math;$(CUDA_PATH)\lib\$(Platform);$(VCInstallDir)lib\amd64;$(WindowsSDK_LibraryPath_x64);$(Platform)</LibraryPath>
<IncludePath>..\CNTKSGDLib;..\CNTKComputationNetworkLib;..\..\Math\Math;..\..\Common\Include;..\..\BrainScript;C:\Program Files (x86)\Microsoft SDKs\MPI\Include;$(CUDA_PATH)\include;$(VCInstallDir)include;$(WindowsSDK_IncludePath)</IncludePath>
<LibraryPath>..\CNTKComputationNetworkLib;..\..\Math\Math;C:\Program Files (x86)\Microsoft SDKs\MPI\Lib\x64;$(CUDA_PATH)\lib\$(Platform);$(VCInstallDir)lib\amd64;$(WindowsSDK_LibraryPath_x64);$(Platform)</LibraryPath>
<IntDir>$(Platform)\$(Configuration)\$(ProjectName)\</IntDir>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
<LinkIncremental>false</LinkIncremental>
<IncludePath>..\CNTKSGDLib;..\CNTKComputationNetworkLib;..\..\Math\Math;..\..\Common\Include;..\..\BrainScript;$(CUDA_PATH)\include;$(VCInstallDir)include;$(WindowsSDK_IncludePath)</IncludePath>
<LibraryPath>..\CNTKComputationNetworkLib;..\..\Math\Math;$(CUDA_PATH)\lib\$(Platform);$(VCInstallDir)lib\amd64;$(WindowsSDK_LibraryPath_x64);$(Platform)</LibraryPath>
<IncludePath>..\CNTKSGDLib;..\CNTKComputationNetworkLib;..\..\Math\Math;..\..\Common\Include;..\..\BrainScript;C:\Program Files (x86)\Microsoft SDKs\MPI\Include;$(CUDA_PATH)\include;$(VCInstallDir)include;$(WindowsSDK_IncludePath)</IncludePath>
<LibraryPath>..\CNTKComputationNetworkLib;..\..\Math\Math;C:\Program Files (x86)\Microsoft SDKs\MPI\Lib\x64;$(CUDA_PATH)\lib\$(Platform);$(VCInstallDir)lib\amd64;$(WindowsSDK_LibraryPath_x64);$(Platform)</LibraryPath>
<IntDir>$(Platform)\$(Configuration)\$(ProjectName)\</IntDir>
</PropertyGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">

View File

@@ -6,9 +6,6 @@
#include "SGD.h"
//#include "MultiNetworksSGD.h"
#include "AllReduceDistGradAggregator.h"
#include "MPIWrapper.h"
extern Microsoft::MSR::CNTK::MPIWrapper *g_mpi;
namespace Microsoft { namespace MSR { namespace CNTK {
@@ -942,11 +939,7 @@ template<class ElemType>
if (g_mpi != nullptr)
g_mpi->WaitAll();
if ((g_mpi == nullptr) || g_mpi->IsMainNode())
{
// only needs to be done by one process
net.SaveToFile(GetModelNameForEpoch(int(startEpoch) - 1));
}
net.SaveToFile(GetModelNameForEpoch(int(startEpoch) - 1));
}
// first, we need to normalize the effect of nbruttsineachrecurrentiter
@@ -1041,9 +1034,8 @@ template<class ElemType>
i + 1, learnRatePerSample, m_minLearnRate);
if (m_autoLearnRateSearchType != LearningRateSearchAlgorithm::None)
{
if ((g_mpi == nullptr) || g_mpi->IsMainNode())
net.SaveToFile(m_modelPath);
}
net.SaveToFile(m_modelPath);
}
break;
}
@@ -1209,8 +1201,7 @@ template<class ElemType>
learnRateReduced = true;
else
{
if ((g_mpi == nullptr) || g_mpi->IsMainNode())
net.SaveToFile(GetModelNameForEpoch(i, true));
net.SaveToFile(GetModelNameForEpoch(i, true));
fprintf(stderr, "Finished training and saved final model\n\n");
break;
@@ -2490,41 +2481,45 @@ template<class ElemType>
const double prevCriterion,
const size_t minibatchSize)
{
wstring checkPointFileName = GetCheckPointFileNameForEpoch(int(epoch));
// Saving into temporary file and then renaming it to the checkPointFileName
// This is a standard trick to avoid having corrupted checkpoint files if the process dies during writing
wstring tempFileName = checkPointFileName + L".tmp";
// In case of parallel training, only the main node should be saving the checkpoint to prevent
// the parallel training nodes from colliding when writing the same file
if ((g_mpi == nullptr) || g_mpi->IsMainNode())
{
File fstream(tempFileName,
FileOptions::fileOptionsBinary | FileOptions::fileOptionsWrite);
fstream.PutMarker(FileMarker::fileMarkerBeginSection, L"BCKP");
wstring checkPointFileName = GetCheckPointFileNameForEpoch(int(epoch));
// Saving into temporary file and then renaming it to the checkPointFileName
// This is a standard trick to avoid having corrupted checkpoint files if the process dies during writing
wstring tempFileName = checkPointFileName + L".tmp";
fstream.PutMarker(FileMarker::fileMarkerBeginSection, L"BLearnRate");
fstream << totalSamplesSeen << learnRatePerSample << prevCriterion;
fstream.PutMarker(FileMarker::fileMarkerEndSection, L"ELearnRate");
fstream.PutMarker(FileMarker::fileMarkerBeginSection, L"BMinibatchSize");
fstream << minibatchSize;
fstream.PutMarker(FileMarker::fileMarkerEndSection, L"EMinibatchSize");
fstream.PutMarker(FileMarker::fileMarkerBeginSection, L"BGradient");
for (auto smoothedGradientIter = smoothedGradients.begin(); smoothedGradientIter != smoothedGradients.end(); smoothedGradientIter++)
{
const Matrix<ElemType>& smoothedGradient = *smoothedGradientIter;
fstream << smoothedGradient;
File fstream(tempFileName, FileOptions::fileOptionsBinary | FileOptions::fileOptionsWrite);
fstream.PutMarker(FileMarker::fileMarkerBeginSection, L"BCKP");
fstream.PutMarker(FileMarker::fileMarkerBeginSection, L"BLearnRate");
fstream << totalSamplesSeen << learnRatePerSample << prevCriterion;
fstream.PutMarker(FileMarker::fileMarkerEndSection, L"ELearnRate");
fstream.PutMarker(FileMarker::fileMarkerBeginSection, L"BMinibatchSize");
fstream << minibatchSize;
fstream.PutMarker(FileMarker::fileMarkerEndSection, L"EMinibatchSize");
fstream.PutMarker(FileMarker::fileMarkerBeginSection, L"BGradient");
for (auto smoothedGradientIter = smoothedGradients.begin(); smoothedGradientIter != smoothedGradients.end(); smoothedGradientIter++)
{
const Matrix<ElemType>& smoothedGradient = *smoothedGradientIter;
fstream << smoothedGradient;
}
fstream.PutMarker(FileMarker::fileMarkerEndSection, L"EGradient");
fstream.PutMarker(FileMarker::fileMarkerEndSection, L"ECKP");
// Ensuring that data is written
fstream.Flush();
}
fstream.PutMarker(FileMarker::fileMarkerEndSection, L"EGradient");
fstream.PutMarker(FileMarker::fileMarkerEndSection, L"ECKP");
// Ensuring that data is written
fstream.Flush();
renameOrDie(tempFileName, checkPointFileName);
}
renameOrDie(tempFileName, checkPointFileName);
}
template<class ElemType>

View File

@@ -366,7 +366,6 @@ CNTK_SRC =\
MachineLearning/CNTKComputationNetworkLib/NetworkBuilderFromConfig.cpp \
MachineLearning/CNTKSGDLib/Profiler.cpp \
MachineLearning/CNTKSGDLib/SGD.cpp \
MachineLearning/CNTKEval/CNTKEval.cpp \
BrainScript/BrainScriptEvaluator.cpp \
BrainScript/BrainScriptParser.cpp \
BrainScript/BrainScriptTest.cpp \