Bugfix: Disable parallelization for commands other than train, test, and cv
Parent: 81a83b7e4b
Commit: 877725d15e
@@ -1 +1 @@
-Subproject commit f785679a6bd5cc089b138b3c6bcb68e4b1f345ae
+Subproject commit 41c1f55b9d5115c4dd051391f38eed8e93fb1860
@@ -59,8 +59,6 @@ static void DoEvalBase(const ConfigParameters& config, IDataReader& reader)
     size_t numMBsToShowResult = config(L"numMBsToShowResult", "100");
     size_t maxSamplesInRAM = config(L"maxSamplesInRAM", (size_t)SIZE_MAX);
     size_t numSubminiBatches = config(L"numSubminibatches", (size_t)1);
-    //TODO: switch to a global parallel setting for both training and evaluation.
-    bool useParallel = config(L"parallelTrain", false);
 
     ConfigArray evalNodeNames = config(L"evalNodeNames", "");
     vector<wstring> evalNodeNamesVector;
@@ -71,7 +69,7 @@ static void DoEvalBase(const ConfigParameters& config, IDataReader& reader)
 
     auto net = ComputationNetwork::CreateFromFile<ElemType>(deviceId, modelPath);
 
-    SimpleEvaluator<ElemType> eval(net, useParallel, numMBsToShowResult, traceLevel, maxSamplesInRAM, numSubminiBatches);
+    SimpleEvaluator<ElemType> eval(net, MPIWrapper::GetInstance(), numMBsToShowResult, traceLevel, maxSamplesInRAM, numSubminiBatches);
     eval.Evaluate(&reader, evalNodeNamesVector, mbSize[0], epochSize);
 }
 
@@ -120,8 +118,6 @@ void DoCrossValidate(const ConfigParameters& config)
     size_t numMBsToShowResult = config(L"numMBsToShowResult", "100");
     size_t maxSamplesInRAM = config(L"maxSamplesInRAM", (size_t)SIZE_MAX);
     size_t numSubminiBatches = config(L"numSubminibatches", (size_t)1);
-    //TODO: switch to a global parallel setting for both training and evaluation.
-    bool useParallel = config(L"parallelTrain", false);
 
     ConfigArray evalNodeNames = config(L"evalNodeNames", "");
     vector<wstring> evalNodeNamesVector;
@@ -155,7 +151,7 @@ void DoCrossValidate(const ConfigParameters& config)
         cvModels.push_back(cvModelPath);
         auto net = ComputationNetwork::CreateFromFile<ElemType>(deviceId, cvModelPath);
 
-        SimpleEvaluator<ElemType> eval(net, useParallel, numMBsToShowResult, traceLevel, maxSamplesInRAM, numSubminiBatches);
+        SimpleEvaluator<ElemType> eval(net, MPIWrapper::GetInstance(), numMBsToShowResult, traceLevel, maxSamplesInRAM, numSubminiBatches);
 
         fprintf(stderr, "model %ls --> \n", cvModelPath.c_str());
         auto evalErrors = eval.Evaluate(&cvDataReader, evalNodeNamesVector, mbSize[0], epochSize);
@@ -172,6 +172,7 @@ void DoTrain(const ConfigRecordType& config)
         optimizer = make_shared<SGD<ElemType>>(configSGD);
     }
 
+    optimizer->InitMPI(MPIWrapper::GetInstance());
     optimizer->Train(createNetworkFn, deviceId, dataReader.get(), cvDataReader.get(), makeMode);
 }
 
@@ -241,6 +242,7 @@ void DoAdapt(const ConfigParameters& config)
 
     SGD<ElemType> sgd(configSGD);
 
+    sgd.InitMPI(MPIWrapper::GetInstance());
     sgd.Adapt(origModelFileName, refNodeName, dataReader.get(), cvDataReader.get(), deviceId, makeMode);
 }
 
@@ -56,9 +56,6 @@
 #define let const auto
 #endif
 
-// TODO: Get rid of these globals
-Microsoft::MSR::CNTK::MPIWrapper* g_mpi = nullptr;
-
 // TODO: Temporary mechanism to enable memory sharing for
 // node output value matrices. This will go away when the
 // sharing is ready to be enabled by default
@@ -154,9 +151,13 @@ static void DisableLegacyUsage(const ConfigParameters& TopLevelConfig, const Con
     }
 }
 
+// When running in parallel with MPI, only commands in 'commandstoRunOnAllRanks' should
+// be run in parallel across multiple ranks. Others should only run on rank 0
+const std::set<std::string> commandstoRunOnAllRanks = { "train", "trainRNN", "adapt", "test", "eval", "cv", "devtest" };
+
 // process the command
 template <typename ElemType>
-void DoCommands(const ConfigParameters& config)
+void DoCommands(const ConfigParameters& config, const shared_ptr<MPIWrapper>& mpi)
 {
     ConfigArray command = config(L"command", "train");
 
@@ -197,7 +198,7 @@ void DoCommands(const ConfigParameters& config)
     std::cerr << "CNTKCommandTrainInfo: CNTKNoMoreCommands_Total : " << fullTotalMaxEpochs << endl;
 
     // set up progress tracing for compute cluster management
-    if (progressTracing && ((g_mpi == nullptr) || g_mpi->IsMainNode()))
+    if (progressTracing && ((mpi == nullptr) || mpi->IsMainNode()))
     {
         ProgressTracing::SetTracingFlag();
         ProgressTracing::TraceTotalNumberOfSteps(fullTotalMaxEpochs); // enable tracing, using this as the total number of epochs
@@ -212,7 +213,7 @@ void DoCommands(const ConfigParameters& config)
         ConfigParameters commandParams(config(command[i]));
         ConfigArray action = commandParams("action", "train");
 
-        if (progressTracing && ((g_mpi == nullptr) || g_mpi->IsMainNode()))
+        if (progressTracing && ((mpi == nullptr) || mpi->IsMainNode()))
         {
             ProgressTracing::SetStepOffset(fullEpochsOffset); // this is the epoch number that SGD will log relative to
         }
@@ -231,74 +232,81 @@ void DoCommands(const ConfigParameters& config)
             fprintf(stderr, "#%*s#\n", (int)(strlen(delim) - 2), "");
             fprintf(stderr, "%s\n\n", delim);
 
-            if (thisAction == "train" || thisAction == "trainRNN")
+            if ((mpi == nullptr) || (commandstoRunOnAllRanks.find(thisAction) != commandstoRunOnAllRanks.end()) || mpi->IsMainNode())
             {
-                std::cerr << "CNTKCommandTrainBegin: " + command[i] << endl;
-                DoTrain<ConfigParameters, ElemType>(commandParams);
-                std::cerr << "CNTKCommandTrainEnd: " + command[i] << endl;
-                fullEpochsOffset += GetMaxEpochs(commandParams);
-            }
-            else if (thisAction == "adapt")
-            {
-                DoAdapt<ElemType>(commandParams);
-            }
-            else if (thisAction == "test" || thisAction == "eval")
-            {
-                DoEval<ElemType>(commandParams);
-            }
-            else if (thisAction == "edit")
-            {
-                DoEdit<ElemType>(commandParams);
-            }
-            else if (thisAction == "cv")
-            {
-                DoCrossValidate<ElemType>(commandParams);
-            }
-            else if (thisAction == "write")
-            {
-                DoWriteOutput<ElemType>(commandParams);
-            }
-            else if (thisAction == "devtest")
-            {
-                TestCn<ElemType>(config); // for "devtest" action pass the root config instead
-            }
-            else if (thisAction == "dumpnode")
-            {
-                DumpNodeInfo<ElemType>(commandParams);
-            }
-            else if (thisAction == "convertdbn")
-            {
-                DoConvertFromDbn<ElemType>(commandParams);
-            }
-            else if (thisAction == "exportdbn")
-            {
-                DoExportToDbn<ElemType>(commandParams);
-            }
-            else if (thisAction == "createLabelMap")
-            {
-                DoCreateLabelMap<ElemType>(commandParams);
-            }
-            else if (thisAction == "writeWordAndClass")
-            {
-                DoWriteWordAndClassInfo<ElemType>(commandParams);
-            }
-            else if (thisAction == "plot")
-            {
-                DoTopologyPlot<ElemType>(commandParams);
-            }
-            else if (thisAction == "SVD")
-            {
-                DoParameterSVD<ElemType>(commandParams);
-            }
-            else
-            {
-                RuntimeError("unknown action: %s in command set: %s", thisAction.c_str(), command[i].c_str());
+                if (thisAction == "train" || thisAction == "trainRNN")
+                {
+                    std::cerr << "CNTKCommandTrainBegin: " + command[i] << endl;
+                    DoTrain<ConfigParameters, ElemType>(commandParams);
+                    std::cerr << "CNTKCommandTrainEnd: " + command[i] << endl;
+                    fullEpochsOffset += GetMaxEpochs(commandParams);
+                }
+                else if (thisAction == "adapt")
+                {
+                    DoAdapt<ElemType>(commandParams);
+                }
+                else if (thisAction == "test" || thisAction == "eval")
+                {
+                    DoEval<ElemType>(commandParams);
+                }
+                else if (thisAction == "edit")
+                {
+                    DoEdit<ElemType>(commandParams);
+                }
+                else if (thisAction == "cv")
+                {
+                    DoCrossValidate<ElemType>(commandParams);
+                }
+                else if (thisAction == "write")
+                {
+                    DoWriteOutput<ElemType>(commandParams);
+                }
+                else if (thisAction == "devtest")
+                {
+                    TestCn<ElemType>(config); // for "devtest" action pass the root config instead
+                }
+                else if (thisAction == "dumpnode")
+                {
+                    DumpNodeInfo<ElemType>(commandParams);
+                }
+                else if (thisAction == "convertdbn")
+                {
+                    DoConvertFromDbn<ElemType>(commandParams);
+                }
+                else if (thisAction == "exportdbn")
+                {
+                    DoExportToDbn<ElemType>(commandParams);
+                }
+                else if (thisAction == "createLabelMap")
+                {
+                    DoCreateLabelMap<ElemType>(commandParams);
+                }
+                else if (thisAction == "writeWordAndClass")
+                {
+                    DoWriteWordAndClassInfo<ElemType>(commandParams);
+                }
+                else if (thisAction == "plot")
+                {
+                    DoTopologyPlot<ElemType>(commandParams);
+                }
+                else if (thisAction == "SVD")
+                {
+                    DoParameterSVD<ElemType>(commandParams);
+                }
+                else
+                {
+                    RuntimeError("unknown action: %s in command set: %s", thisAction.c_str(), command[i].c_str());
+                }
             }
 
             fprintf(stderr, "\nAction \"%s\" complete.\n\n", thisAction.c_str());
 
             NDLScript<ElemType> ndlScript;
             ndlScript.ClearGlobal(); // clear global macros between commands
+
+            // Synchronize all ranks before proceeding to next action/command
+            if (mpi != nullptr)
+                mpi->WaitAll();
         }
     }
 }
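The heart of the fix is the whitelist check above: when MPI is active, only actions listed in commandstoRunOnAllRanks execute on every rank; anything else runs on the main rank only, and all ranks synchronize before the next action. Below is a minimal, stand-alone sketch of that gating rule in plain C++ (no CNTK or MPI types; all names here are illustrative, not the project's APIs):

#include <iostream>
#include <set>
#include <string>

// Hypothetical stand-in for MPIWrapper::IsMainNode(): rank 0 is the main node.
static bool IsMainNode(int rank) { return rank == 0; }

// Actions that are parallel-aware and therefore must run on every rank.
static const std::set<std::string> commandsToRunOnAllRanks = {
    "train", "trainRNN", "adapt", "test", "eval", "cv", "devtest"};

// Returns true if 'action' should execute on the given rank.
// mpiActive == false models the single-process case (mpi == nullptr in the commit).
bool ShouldRunOnThisRank(const std::string& action, bool mpiActive, int rank)
{
    return !mpiActive ||
           commandsToRunOnAllRanks.count(action) > 0 ||
           IsMainNode(rank);
}

int main()
{
    // "write" is not whitelisted, so with MPI active it runs on rank 0 only,
    // while "train" runs everywhere.
    for (int rank = 0; rank < 3; ++rank)
        std::cout << "rank " << rank
                  << " runs 'write': " << ShouldRunOnThisRank("write", /*mpiActive=*/true, rank)
                  << ", runs 'train': " << ShouldRunOnThisRank("train", /*mpiActive=*/true, rank)
                  << "\n";
    return 0;
}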
@@ -459,10 +467,10 @@ int wmainWithBS(int argc, wchar_t* argv[]) // called from wmain which is a wrapp
         InvalidArgument("Legacy name 'type' no longer allowed. Use 'precision'.");
 
     // parallel training
-    g_mpi = nullptr;
+    shared_ptr<Microsoft::MSR::CNTK::MPIWrapper> mpi;
     bool paralleltrain = config(L"parallelTrain", false);
     if (paralleltrain)
-        g_mpi = new MPIWrapper();
+        mpi = MPIWrapper::GetInstance(true /*create*/);
 
     g_shareNodeValueMatrices = config(L"shareNodeValueMatrices", false);
 
@@ -476,7 +484,7 @@ int wmainWithBS(int argc, wchar_t* argv[]) // called from wmain which is a wrapp
         logpath += L".log"; // TODO: why do we need to append this here?
 
         if (paralleltrain)
-            logpath += msra::strfun::wstrprintf(L"rank%d", (int) g_mpi->CurrentNodeRank());
+            logpath += msra::strfun::wstrprintf(L"rank%d", (int) mpi->CurrentNodeRank());
 
         RedirectStdErr(logpath);
         fprintf(stderr, "%ls\n", startupMessage.c_str());
@@ -495,7 +503,7 @@ int wmainWithBS(int argc, wchar_t* argv[]) // called from wmain which is a wrapp
     bool progressTracing = config(L"progressTracing", false);
     size_t fullTotalMaxEpochs = 1; // BUGBUG: BS does not allow me to read out the max epochs parameters, as that would instantiate and thus execute the objects
     // set up progress tracing for compute cluster management
-    if (progressTracing && ((g_mpi == nullptr) || g_mpi->IsMainNode()))
+    if (progressTracing && ((mpi == nullptr) || mpi->IsMainNode()))
         ProgressTracing::TraceTotalNumberOfSteps(fullTotalMaxEpochs); // enable tracing, using this as the total number of epochs
 
     // MAIN LOOP that executes the actions
@@ -508,6 +516,8 @@ int wmainWithBS(int argc, wchar_t* argv[]) // called from wmain which is a wrapp
     const ScriptableObjects::ConfigArray& actions = actionsVal;
     for (int i = actions.GetIndexRange().first; i <= actions.GetIndexRange().second; i++)
     {
+        // TODO: When running in parallel with MPI, only commands in 'commandstoRunOnAllRanks' should
+        // be run in parallel across multiple ranks. Others should only run on rank 0
         actions.At(i, [](const wstring&)
                    {
                    }); // this will evaluate and thus execute the action
@@ -525,7 +535,7 @@ int wmainWithBS(int argc, wchar_t* argv[]) // called from wmain which is a wrapp
     }
     fprintf(stderr, "COMPLETED\n"), fflush(stderr);
 
-    delete g_mpi;
+    MPIWrapper::DeleteInstance();
     return EXIT_SUCCESS;
 }
 
@@ -546,12 +556,10 @@ int wmainOldCNTKConfig(int argc, wchar_t* argv[]) // called from wmain which is
     ConfigArray command = config(L"command", "train");
 
     // paralleltrain training
-    g_mpi = nullptr;
+    shared_ptr<Microsoft::MSR::CNTK::MPIWrapper> mpi;
     bool paralleltrain = config(L"parallelTrain", "false");
     if (paralleltrain)
-    {
-        g_mpi = new MPIWrapper();
-    }
+        mpi = MPIWrapper::GetInstance(true /*create*/);
 
     g_shareNodeValueMatrices = config(L"shareNodeValueMatrices", false);
 
@@ -569,7 +577,7 @@ int wmainOldCNTKConfig(int argc, wchar_t* argv[]) // called from wmain which is
     if (paralleltrain)
     {
         std::wostringstream oss;
-        oss << g_mpi->CurrentNodeRank();
+        oss << mpi->CurrentNodeRank();
         logpath += L"rank" + oss.str();
     }
     RedirectStdErr(logpath);
@@ -616,9 +624,9 @@ int wmainOldCNTKConfig(int argc, wchar_t* argv[]) // called from wmain which is
 
     fprintf(stderr, "\nPrecision = \"%s\"\n", type.c_str());
     if (type == "float")
-        DoCommands<float>(config);
+        DoCommands<float>(config, mpi);
     else if (type == "double")
-        DoCommands<double>(config);
+        DoCommands<double>(config, mpi);
     else
         RuntimeError("CNTK: Invalid precision string: \"%s\", must be \"float\" or \"double\"", type.c_str());
 
@@ -631,7 +639,7 @@ int wmainOldCNTKConfig(int argc, wchar_t* argv[]) // called from wmain which is
     }
     fprintf(stderr, "COMPLETED\n"), fflush(stderr);
 
-    delete g_mpi;
+    MPIWrapper::DeleteInstance();
     return EXIT_SUCCESS;
 }
 
@@ -10,6 +10,7 @@
 #include <string>
 #include <array>
 #include <vector>
+#include <memory>
 
 namespace Microsoft { namespace MSR { namespace CNTK {
 
@@ -49,7 +50,7 @@ static int operator||(int rc, const MpiFail &what)
     RuntimeError("%s", what.c_str());
 }
 
-class MPIWrapper
+class MPIWrapper : public std::enable_shared_from_this<MPIWrapper>
 {
     int m_myRank;
     int m_numMPINodes;
@@ -58,6 +59,8 @@ class MPIWrapper
     // MPI communicator that reflects the current subset selection
     MPI_Comm m_currentComm;
 
+    static std::shared_ptr<MPIWrapper> s_mpi;
+
     // MPI_Init() with delay-loading the msmpi.dll (possibly causing a failure if missing; we want to catch that)
     int MPI_Init_DL()
     {
@@ -100,7 +103,6 @@ class MPIWrapper
     static int s_myRank;
     static void MPIWorkaroundAtExit()
     {
-        // Note: we can't use g_mpi, since MPI stack is already down at this point
         Sleep(s_myRank * 50);
     }
 
@@ -155,6 +157,7 @@ public:
         MPI_Finalize();
     }
 
+private:
     void Ping(const char *msg) const
     {
 #undef USE2NDCOMM
@@ -218,6 +221,28 @@ public:
         Ping("requestnodes (after change)");
     }
 
+public:
+
+    static std::shared_ptr<MPIWrapper> GetInstance(bool create = false)
+    {
+        static bool initialized = false;
+        if (create)
+        {
+            if (initialized)
+                LogicError("Creating MPIWrapper instance after a GetInstance call has been already made!");
+            else
+                s_mpi = std::make_shared<MPIWrapper>();
+        }
+
+        initialized = true;
+        return s_mpi;
+    }
+
+    static void DeleteInstance()
+    {
+        s_mpi = nullptr;
+    }
+
     MPI_Comm Communicator() const
     {
         return m_currentComm;
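For reference, here is a stand-alone sketch of the create-once accessor pattern that GetInstance(bool create) follows: the first caller that passes create = true builds the shared instance, every later caller gets the same shared_ptr (or nullptr if it was never created), and DeleteInstance releases it. This illustrates the pattern only; it is not the CNTK class itself and the names are made up:

#include <cassert>
#include <memory>
#include <stdexcept>

// Illustrative stand-in for the MPIWrapper singleton accessor.
class Wrapper
{
    static std::shared_ptr<Wrapper> s_instance;

public:
    static std::shared_ptr<Wrapper> GetInstance(bool create = false)
    {
        static bool initialized = false;
        if (create)
        {
            if (initialized)
                throw std::logic_error("create requested after GetInstance was already called");
            s_instance = std::make_shared<Wrapper>();
        }
        initialized = true;
        return s_instance; // may be nullptr if the instance was never created
    }

    static void DeleteInstance() { s_instance = nullptr; }
};

std::shared_ptr<Wrapper> Wrapper::s_instance = nullptr;

int main()
{
    auto a = Wrapper::GetInstance(true /*create*/); // e.g. when parallelTrain=true
    auto b = Wrapper::GetInstance();                // later callers share the same object
    assert(a == b);
    Wrapper::DeleteInstance();                      // mirrors MPIWrapper::DeleteInstance()
    return 0;
}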
@@ -312,8 +337,5 @@ public:
         MPI_Barrier(m_currentComm) || MpiFail("waitall: MPI_Barrier");
     }
 };
-}
-}
-}
 
-extern Microsoft::MSR::CNTK::MPIWrapper *g_mpi;
+}}}
@@ -2,3 +2,4 @@
 #include "Include/MPIWrapper.h"
 
 int Microsoft::MSR::CNTK::MPIWrapper::s_myRank = -1;
+std::shared_ptr<Microsoft::MSR::CNTK::MPIWrapper> Microsoft::MSR::CNTK::MPIWrapper::s_mpi = nullptr;
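The new static s_mpi member declared in the header still needs exactly one out-of-class definition, which is what this .cpp line supplies. A tiny self-contained illustration of that C++ rule (the type and names here are invented for the example):

#include <cassert>
#include <memory>

struct Widget
{
    // Declaration only; no storage is allocated here.
    static std::shared_ptr<Widget> s_instance;
};

// The one and only definition, normally placed in a single .cpp file.
std::shared_ptr<Widget> Widget::s_instance = nullptr;

int main()
{
    assert(Widget::s_instance == nullptr);
    Widget::s_instance = std::make_shared<Widget>();
    assert(Widget::s_instance != nullptr);
    return 0;
}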
@@ -86,17 +86,11 @@ void ComputationNetwork::SaveEdited(const wstring& fileName, const FileOptions f
 void ComputationNetwork::Save(const wstring& fileName, const FileOptions fileFormat) const
 {
     VerifyIsCompiled("Save");
-    // In case of parallel training only the main node should we saving the model to prevent
-    // the parallel training nodes from colliding to write the same file
-    // TODO: This does not belong here.
-    if ((g_mpi == nullptr) || g_mpi->IsMainNode())
-    {
-        // Saving into temporary file and then renaming it to the requested fileName
-        // This is a standard trick to avoid havign corrupted model files if process dies during writing
-        wstring tmpFileName = fileName + L".tmp";
-        SaveToFileImpl(tmpFileName, fileFormat);
-        renameOrDie(tmpFileName, fileName);
-    }
+    // Saving into temporary file and then renaming it to the requested fileName
+    // This is a standard trick to avoid havign corrupted model files if process dies during writing
+    wstring tmpFileName = fileName + L".tmp";
+    SaveToFileImpl(tmpFileName, fileFormat);
+    renameOrDie(tmpFileName, fileName);
 }
 
 // TODO: how does the file distinguish float vs double nodes?
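The rank check moves out of Save (callers such as SGD now gate the call themselves), leaving only the usual crash-safe pattern: write to a temporary file, then rename it over the target. A minimal stand-alone sketch of that trick, using std::filesystem in place of CNTK's renameOrDie; the file name and contents are made up:

#include <filesystem>
#include <fstream>
#include <stdexcept>
#include <string>

// Write 'contents' to 'fileName' via a temporary file so that a crash mid-write
// does not leave a truncated file at the final path.
void SaveAtomically(const std::string& fileName, const std::string& contents)
{
    const std::string tmpFileName = fileName + ".tmp";
    {
        std::ofstream out(tmpFileName, std::ios::binary | std::ios::trunc);
        if (!out)
            throw std::runtime_error("cannot open " + tmpFileName);
        out << contents;
    } // stream flushed and closed here
    std::filesystem::rename(tmpFileName, fileName); // replaces the old file if present
}

int main()
{
    SaveAtomically("model.dnn", "fake model bytes"); // illustrative file name
    return std::filesystem::exists("model.dnn") ? 0 : 1;
}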
@@ -17,9 +17,6 @@
 #include "BestGpu.h"
 #include "MPIWrapper.h"
 
-// TODO: Get rid of this global
-Microsoft::MSR::CNTK::MPIWrapper* g_mpi = nullptr;
-
 // TODO: Temporary mechanism to enable memory sharing for
 // node output value matrices. This will go away when the
 // sharing is ready to be enabled by default
@@ -30,7 +30,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {
                                         bool useDistributedMBReading,
                                         bool useParallelTrain,
                                         StreamMinibatchInputs& inputMatrices,
-                                        size_t& actualMBSize)
+                                        size_t& actualMBSize,
+                                        const std::shared_ptr<MPIWrapper>& mpi)
     {
         auto pMBLayout = net->GetMBLayoutPtr();
         // Reading consists of a sequence of Reader API calls:
@@ -66,7 +67,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
 
         // decimate if needed. Decimation happens in-place.
         if (!useDistributedMBReading && useParallelTrain)
-            DecimateMinibatchInPlace<ElemType>(inputMatrices, g_mpi->NumNodesInUse(), g_mpi->CurrentNodeRank(), net->GetMBLayoutPtr());
+            DecimateMinibatchInPlace<ElemType>(inputMatrices, mpi->NumNodesInUse(), mpi->CurrentNodeRank(), net->GetMBLayoutPtr());
 
         // reader will have resized input node's m_value directly. Nodes must be notified to do necessary internal state updates from that.
         // TODO: This is a stopgap. SGD will at some point change from sets of matrices to sets of nodes. Then this will become much simpler.
@@ -9,7 +9,7 @@ template <class ElemType>
 class IDistGradAggregator
 {
 public:
-    IDistGradAggregator(MPIWrapper* mpi)
+    IDistGradAggregator(const std::shared_ptr<MPIWrapper>& mpi)
         : m_mpi(mpi)
     {
     }
@@ -37,7 +37,7 @@ public:
     }
 
 protected:
-    MPIWrapper* m_mpi;
+    std::shared_ptr<MPIWrapper> m_mpi;
 };
 
 #define UsingIDistGradAggregatorMembers \
@@ -102,7 +102,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
     {
         typedef shared_ptr<ComputationNode<ElemType>> ComputationNodePtr;
     public:
-        IMASGD(MPIWrapper* pMPI, size_t perfReportFreq)
+        IMASGD(const std::shared_ptr<MPIWrapper>& pMPI, size_t perfReportFreq)
             :m_MAworkerStatus(pMPI->NumNodesInUse(), MAWorkerStatus::NOTSTARTED),
             m_numSyncPerformed(0),
             m_numWorkers(pMPI->NumNodesInUse()),
@@ -120,7 +120,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
         {
             m_MAworkerStatus.resize(m_numWorkers);
             std::fill(m_MAworkerStatus.begin(), m_MAworkerStatus.end(), MAWorkerStatus::DataProcessing);
-            g_mpi->WaitAll();
+            m_pMPI->WaitAll();
             m_perfReporter.OnEpochStart();
         }
 
@@ -286,8 +286,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
         size_t m_numWorkers;
         size_t m_myRank;
         MASGDPerfStats m_perfReporter;
-        MPIWrapper* m_pMPI; // TODO: to use shared_ptr in the future
-
+        std::shared_ptr<MPIWrapper> m_pMPI;
     };
 
 
@@ -300,7 +299,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
         using Base::DownCast;
 
     public:
-        BasicModelAveragingSGD(MPIWrapper* pMPI, size_t reportFreq)
+        BasicModelAveragingSGD(const std::shared_ptr<MPIWrapper>& pMPI, size_t reportFreq)
             :Base(pMPI, reportFreq)
         {}
 
@@ -327,7 +326,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
             if (nTotalSamples <= 0)
             {
                 // prepare for overflow
-                factor = 1.0f / g_mpi->NumNodesInUse();
+                factor = 1.0f / m_pMPI->NumNodesInUse();
                 totalSamplesProcessed = samplesSinceLastSync * m_pMPI->NumNodesInUse();
                 // give an estimated one
             }
@@ -261,11 +261,11 @@ void SGD<ElemType>::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
         prevLearnRates[i] = -1.0;
     }
 
-    if (m_parallelizationMethod == ParallelizationMethod::DataParallelSGD)
+    if (GetParallelizationMethod() == ParallelizationMethod::DataParallelSGD)
     {
         InitDistGradAgg(evaluationNodes.size(), m_traceLevel);
     }
-    else if (m_parallelizationMethod == ParallelizationMethod::ModelAveragingSGD)
+    else if (GetParallelizationMethod() == ParallelizationMethod::ModelAveragingSGD)
     {
         InitModelAggregationHandler(m_syncStatsTrace);
     }
@@ -278,12 +278,15 @@ void SGD<ElemType>::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
     {
         // Synchronize all ranks before writing the model to ensure that
         // everyone is done loading the model
-        if (g_mpi != nullptr)
+        if (m_mpi != nullptr)
         {
-            g_mpi->WaitAll();
+            m_mpi->WaitAll();
         }
 
-        net->Save(GetModelNameForEpoch(int(startEpoch) - 1));
+        // In case of parallel training only the main node should we saving the model to prevent
+        // the parallel training nodes from colliding to write the same file
+        if ((m_mpi == nullptr) || m_mpi->IsMainNode())
+            net->Save(GetModelNameForEpoch(int(startEpoch) - 1));
     }
 
     bool learnRateInitialized = false;
@@ -332,9 +335,9 @@ void SGD<ElemType>::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
     {
         // Synchronize all ranks before proceeding to ensure that
         // rank 0 has finished writing the previous model file
-        if (g_mpi != nullptr)
+        if (m_mpi != nullptr)
         {
-            g_mpi->WaitAll();
+            m_mpi->WaitAll();
         }
 
         Timer timer;
@@ -384,7 +387,10 @@ void SGD<ElemType>::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
                     i + 1, learnRatePerSample, m_minLearnRate);
             if (m_autoLearnRateSearchType != LearningRateSearchAlgorithm::None)
             {
-                net->Save(m_modelPath);
+                // In case of parallel training only the main node should we saving the model to prevent
+                // the parallel training nodes from colliding to write the same file
+                if ((m_mpi == nullptr) || m_mpi->IsMainNode())
+                    net->Save(m_modelPath);
             }
             break;
         }
@@ -501,7 +507,7 @@ void SGD<ElemType>::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
 
         if (validationSetDataReader != trainSetDataReader && validationSetDataReader != nullptr)
         {
-            SimpleEvaluator<ElemType> evalforvalidation(net, g_mpi != nullptr);
+            SimpleEvaluator<ElemType> evalforvalidation(net, m_mpi);
             vector<wstring> cvSetTrainAndEvalNodes;
             if (criterionNodes.size() > 0)
             {
@@ -535,10 +541,10 @@ void SGD<ElemType>::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
         }
 
         // broadcast epochCriterion to make sure each processor will have the same learning rate schedule
-        if ((m_parallelizationMethod == ParallelizationMethod::ModelAveragingSGD) && (g_mpi->NumNodesInUse() > 1))
+        if ((GetParallelizationMethod() == ParallelizationMethod::ModelAveragingSGD) && (m_mpi->NumNodesInUse() > 1))
         {
-            g_mpi->Bcast(&epochCriterion, 1, g_mpi->MainNodeRank());
-            g_mpi->Bcast(&lrControlCriterion, 1, g_mpi->MainNodeRank());
+            m_mpi->Bcast(&epochCriterion, 1, m_mpi->MainNodeRank());
+            m_mpi->Bcast(&lrControlCriterion, 1, m_mpi->MainNodeRank());
         }
 
         bool loadedPrevModel = false;
@@ -587,7 +593,10 @@ void SGD<ElemType>::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
             }
             else
             {
-                net->Save(GetModelNameForEpoch(i, true));
+                // In case of parallel training only the main node should we saving the model to prevent
+                // the parallel training nodes from colliding to write the same file
+                if ((m_mpi == nullptr) || m_mpi->IsMainNode())
+                    net->Save(GetModelNameForEpoch(i, true));
 
                 fprintf(stderr, "Finished training and saved final model\n\n");
                 break;
@@ -634,13 +643,13 @@ void SGD<ElemType>::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
         // Synchronize all ranks before proceeding to ensure that
         // nobody tries reading the checkpoint file at the same time
         // as rank 0 deleting it below
-        if (g_mpi != nullptr)
+        if (m_mpi != nullptr)
         {
-            g_mpi->WaitAll();
+            m_mpi->WaitAll();
         }
 
         // persist model and check-point info
-        if ((g_mpi == nullptr) || g_mpi->IsMainNode())
+        if ((m_mpi == nullptr) || m_mpi->IsMainNode())
         {
             SaveCheckPointInfo(i, totalSamplesSeen, learnRatePerSample, smoothedGradients, prevCriterion, chosenMinibatchSize);
             auto modelName = GetModelNameForEpoch(i);
@@ -677,9 +686,9 @@ void SGD<ElemType>::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
 
         // Synchronize all ranks before proceeding to ensure that
        // rank 0 has finished writing the model file
-        if (g_mpi != nullptr)
+        if (m_mpi != nullptr)
         {
-            g_mpi->WaitAll();
+            m_mpi->WaitAll();
         }
 
         // progress tracing for compute cluster management
@@ -746,9 +755,9 @@ size_t SGD<ElemType>::TrainOneEpoch(ComputationNetworkPtr net,
     localEpochCriterion.SetValue(0);
     localEpochEvalErrors.SetValue(0);
 
-    bool useGradientAggregation = ((m_parallelizationMethod == ParallelizationMethod::DataParallelSGD) &&
+    bool useGradientAggregation = ((GetParallelizationMethod() == ParallelizationMethod::DataParallelSGD) &&
                                    (epochNumber >= m_parallelizationStartEpochNum));
-    bool useModelAveraging = ((m_parallelizationMethod == ParallelizationMethod::ModelAveragingSGD) &&
+    bool useModelAveraging = ((GetParallelizationMethod() == ParallelizationMethod::ModelAveragingSGD) &&
                               (epochNumber >= m_parallelizationStartEpochNum));
     bool useParallelTrain = useGradientAggregation || useModelAveraging;
 
@@ -778,8 +787,8 @@ size_t SGD<ElemType>::TrainOneEpoch(ComputationNetworkPtr net,
         trainSetDataReader->SupportsDistributedMBRead();
     if (useDistributedMBReading)
     {
-        trainSetDataReader->StartDistributedMinibatchLoop(tunedMBSize, epochNumber, g_mpi->CurrentNodeRank(),
-                                                          g_mpi->NumNodesInUse(), epochSize);
+        trainSetDataReader->StartDistributedMinibatchLoop(tunedMBSize, epochNumber, m_mpi->CurrentNodeRank(),
+                                                          m_mpi->NumNodesInUse(), epochSize);
     }
     else
     {
@@ -813,7 +822,7 @@ size_t SGD<ElemType>::TrainOneEpoch(ComputationNetworkPtr net,
     if (useGradientAggregation)
     {
         fprintf(stderr, ", DataParallelSGD training (MyRank = %d, NumNodes = %d, NumGradientBits = %d)",
-                (int) g_mpi->CurrentNodeRank(), (int) g_mpi->NumNodesInUse(), (int) m_numGradientBits);
+                (int) m_mpi->CurrentNodeRank(), (int) m_mpi->NumNodesInUse(), (int) m_numGradientBits);
         if (m_bufferedAsyncGradientAggregation)
         {
             fprintf(stderr, ", BufferedAsyncGradientAggregation is ENABLED");
@@ -844,7 +853,7 @@ size_t SGD<ElemType>::TrainOneEpoch(ComputationNetworkPtr net,
         // TODO: is it guaranteed that the GPU is already completed at this point, is it safe to overwrite the buffers?
         size_t actualMBSize = 0;
         bool wasDataRead = DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(*trainSetDataReader, net, criterionNodes[0],
-                                                                                useDistributedMBReading, useParallelTrain, *inputMatrices, actualMBSize);
+                                                                                useDistributedMBReading, useParallelTrain, *inputMatrices, actualMBSize, m_mpi);
         if (!wasDataRead && (!useDistributedMBReading || noMoreSamplesToProcess)) // in case of distributed reading, we do a few more loops until all ranks have completed
             break; // end of epoch
 
@@ -1194,11 +1203,11 @@ size_t SGD<ElemType>::TrainOneEpoch(ComputationNetworkPtr net,
     }
 
     // in case of model averaging, do one more final aggregation of criteria
-    if (useModelAveraging && (g_mpi->NumNodesInUse() > 1))
+    if (useModelAveraging && (m_mpi->NumNodesInUse() > 1))
     {
         // 1. total epoch samples processed by all workers
         size_t totalEpochSamplesOfAllWorkers = totalEpochSamples;
-        g_mpi->AllReduce(&totalEpochSamplesOfAllWorkers, 1);
+        m_mpi->AllReduce(&totalEpochSamplesOfAllWorkers, 1);
         totalSamplesSeen += totalEpochSamplesOfAllWorkers;
 
         // 2. criterion and EvalErrors
@@ -1211,8 +1220,8 @@ size_t SGD<ElemType>::TrainOneEpoch(ComputationNetworkPtr net,
             epochEvalErrors[i] = localEpochEvalErrors(0, i);
         }
         // merge epochCriterion and epochEvalErrors over nodes
-        g_mpi->AllReduce(&epochCriterion, 1);
-        g_mpi->AllReduce(epochEvalErrors);
+        m_mpi->AllReduce(&epochCriterion, 1);
+        m_mpi->AllReduce(epochEvalErrors);
 
         // 3. modify return value
         totalEpochSamples = totalEpochSamplesOfAllWorkers;
@@ -1298,7 +1307,7 @@ bool SGD<ElemType>::PreCompute(ComputationNetworkPtr net,
     const size_t numIterationsBeforePrintingProgress = 100;
     size_t numItersSinceLastPrintOfProgress = 0;
     size_t actualMBSizeDummy;
-    while (DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(*trainSetDataReader, net, nullptr, false, false, *inputMatrices, actualMBSizeDummy))
+    while (DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(*trainSetDataReader, net, nullptr, false, false, *inputMatrices, actualMBSizeDummy, m_mpi))
     {
         // TODO: move these into GetMinibatchIntoNetwork() --but those are passed around; necessary? Can't we get them from 'net'?
         ComputationNetwork::BumpEvalTimeStamp(featureNodes);
@@ -1821,19 +1830,19 @@ int SGD<ElemType>::SGDTrace(FILE* __restrict __stream, const char* __restrict __
 template <class ElemType>
 void SGD<ElemType>::InitDistGradAgg(int numEvalNodes, int traceLevel)
 {
-    if (m_parallelizationMethod == ParallelizationMethod::DataParallelSGD)
+    if (GetParallelizationMethod() == ParallelizationMethod::DataParallelSGD)
     {
         if (m_distGradAgg == nullptr)
         {
 #ifdef QUANTIZED_GRADIENT_AGGREGATION
-            m_distGradAgg = new AllReduceDistGradAggregator<ElemType>(g_mpi, m_numGradientBits, m_zeroThresholdFor1Bit, true /*useQuantizationForSelfStripe*/, m_bufferedAsyncGradientAggregation, traceLevel, m_syncStatsTrace);
+            m_distGradAgg = new AllReduceDistGradAggregator<ElemType>(m_mpi, m_numGradientBits, m_zeroThresholdFor1Bit, true /*useQuantizationForSelfStripe*/, m_bufferedAsyncGradientAggregation, traceLevel, m_syncStatsTrace);
 #else
             if (m_numGradientBits != (8 * sizeof(ElemType)))
             {
                 RuntimeError("Gradient quantization is unsupported in CNTK binaries built without quantized gradient aggregation support!");
             }
 
-            m_distGradAgg = new SimpleDistGradAggregator<ElemType>(g_mpi, m_bufferedAsyncGradientAggregation, m_syncStatsTrace);
+            m_distGradAgg = new SimpleDistGradAggregator<ElemType>(m_mpi, m_bufferedAsyncGradientAggregation, m_syncStatsTrace);
 #endif // !QUANTIZED_GRADIENT_AGGREGATION
         }
 
@@ -1847,12 +1856,12 @@ void SGD<ElemType>::InitDistGradAgg(int numEvalNodes, int traceLevel)
 template <class ElemType>
 void SGD<ElemType>::InitModelAggregationHandler(int traceLevel)
 {
-    if (m_parallelizationMethod == ParallelizationMethod::ModelAveragingSGD)
+    if (GetParallelizationMethod() == ParallelizationMethod::ModelAveragingSGD)
     {
 #ifndef BLOCKWISE_MODEL_UPDATE_FILTERING
         if (!m_pMASGDHelper)
         {
-            m_pMASGDHelper = make_shared<BasicModelAveragingSGD<ElemType>>(g_mpi, traceLevel);
+            m_pMASGDHelper = make_shared<BasicModelAveragingSGD<ElemType>>(m_mpi, traceLevel);
         }
 #else
 
@@ -2011,7 +2020,7 @@ void SGD<ElemType>::SaveCheckPointInfo(const size_t epoch, const size_t totalSam
 {
     // In case of parallel training only the main node should we saving the checkpoint to prevent
     // the parallel training nodes from colliding to write the same file
-    if ((g_mpi == nullptr) || g_mpi->IsMainNode())
+    if ((m_mpi == nullptr) || m_mpi->IsMainNode())
     {
         wstring checkPointFileName = GetCheckPointFileNameForEpoch(int(epoch));
         // Saving into temporary file and then renaming it to the checkPointFileName
@@ -2538,7 +2547,7 @@ SGDParams::SGDParams(const ConfigRecordType& configSGD, size_t sizeofElemType)
     m_parallelizationStartEpochNum = 0;
     m_nFramesBetweenMASync = 40000; // default 40k frames
 
-    if ((g_mpi != nullptr) && configSGD.Exists(L"ParallelTrain"))
+    if (configSGD.Exists(L"ParallelTrain"))
     {
         const ConfigRecordType& configParallelTrain(configSGD(L"ParallelTrain", ConfigRecordType::Record()));
         m_parallelizationMethod = ParseParallelizationMethod(configParallelTrain(L"parallelizationMethod", L"none"));
@@ -143,6 +143,14 @@ protected:
         return pow(m_momentumParam[epoch], 1.0 / FixUpEffectiveMBSize(m_momentumSpecifiedForMBSize[epoch], numParallelSequences));
     }
 
+    ParallelizationMethod GetParallelizationMethod() const
+    {
+        if (m_mpi == nullptr)
+            return ParallelizationMethod::None;
+
+        return m_parallelizationMethod;
+    }
+
     // only true when the user specify LearningRatePerMB and the number of parallel utterances in Reader > 1
     // bool m_needToNormalizeLRByParallUtterance; // TODO: should go away
     // bool m_needToNormalizeMomentumByParallUtterance;
@@ -228,6 +236,8 @@ protected:
     bool m_useAllDataForPreComputedNode;
 
     // Parallel training
+    std::shared_ptr<MPIWrapper> m_mpi;
+
     ParallelizationMethod m_parallelizationMethod;
     bool m_enableDistributedMBReading;
     int m_parallelizationStartEpochNum;
@@ -303,6 +313,14 @@ public:
     {
     }
 
+    void InitMPI(const std::shared_ptr<MPIWrapper>& mpi)
+    {
+        m_mpi = mpi;
+
+        if (m_mpi == nullptr)
+            m_parallelizationMethod = ParallelizationMethod::None;
+    }
+
     void Train(function<ComputationNetworkPtr(DEVICEID_TYPE)> createNetworkFn, DEVICEID_TYPE deviceId,
                IDataReader* trainSetDataReader,
                IDataReader* validationSetDataReader,
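Together, InitMPI and GetParallelizationMethod make SGD degrade gracefully: if no MPI wrapper is handed in (single-process runs, or commands outside the whitelist), the configured parallelization method is reported as None, so the DataParallelSGD and ModelAveragingSGD branches are never taken. A stand-alone sketch of that fallback with simplified stand-in types (not the real SGD or MPIWrapper classes):

#include <iostream>
#include <memory>

enum class ParallelizationMethod { None, DataParallelSGD, ModelAveragingSGD };

struct FakeMpi {}; // stand-in for MPIWrapper

class Trainer // stand-in for SGD<ElemType>
{
    std::shared_ptr<FakeMpi> m_mpi;
    ParallelizationMethod m_parallelizationMethod = ParallelizationMethod::DataParallelSGD;

public:
    void InitMPI(const std::shared_ptr<FakeMpi>& mpi)
    {
        m_mpi = mpi;
        if (m_mpi == nullptr)
            m_parallelizationMethod = ParallelizationMethod::None; // no MPI -> no parallel training
    }

    ParallelizationMethod GetParallelizationMethod() const
    {
        return (m_mpi == nullptr) ? ParallelizationMethod::None : m_parallelizationMethod;
    }
};

int main()
{
    Trainer single;
    single.InitMPI(nullptr); // e.g. parallelTrain=false, so no wrapper was created
    std::cout << (single.GetParallelizationMethod() == ParallelizationMethod::None
                      ? "running single-process SGD\n"
                      : "unexpected\n");

    Trainer parallel;
    parallel.InitMPI(std::make_shared<FakeMpi>()); // MPI available -> keep configured method
    std::cout << (parallel.GetParallelizationMethod() == ParallelizationMethod::DataParallelSGD
                      ? "DataParallelSGD enabled\n"
                      : "unexpected\n");
    return 0;
}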
@@ -15,7 +15,7 @@ class SimpleDistGradAggregator : public IDistGradAggregator<ElemType>
     UsingIDistGradAggregatorMembers;
 
 public:
-    SimpleDistGradAggregator(MPIWrapper* mpi, bool useAsyncAggregation, int syncStatsTrace)
+    SimpleDistGradAggregator(const std::shared_ptr<MPIWrapper>& mpi, bool useAsyncAggregation, int syncStatsTrace)
         : IDistGradAggregator<ElemType>(mpi), m_useAsyncAggregation(useAsyncAggregation), m_currentEpochNumber(-1), m_bufferedGradHeader(nullptr), m_syncStatsTrace(syncStatsTrace), m_iterationCount(0)
     {
     }
@@ -31,14 +31,14 @@ template <class ElemType>
 class SimpleEvaluator
 {
 public:
-    SimpleEvaluator(ComputationNetworkPtr net, const bool parallelRun, const size_t numMBsToShowResult = 100, const int traceLevel = 0, const size_t maxSamplesInRAM = SIZE_MAX,
+    SimpleEvaluator(ComputationNetworkPtr net, const std::shared_ptr<MPIWrapper>& mpi, const size_t numMBsToShowResult = 100, const int traceLevel = 0, const size_t maxSamplesInRAM = SIZE_MAX,
                     const size_t numSubminiBatches = 1)
         : m_net(net),
           m_numMBsToShowResult(numMBsToShowResult),
          m_traceLevel(traceLevel),
          m_maxSamplesInRAM(maxSamplesInRAM),
          m_numSubminiBatches(numSubminiBatches),
-          m_parallelRun(parallelRun),
+          m_mpi(mpi),
          m_distGradAgg(nullptr),
          m_gradHeader(nullptr)
     {
@@ -123,7 +123,7 @@ public:
 
         const size_t numIterationsBeforePrintingProgress = 100;
         size_t numItersSinceLastPrintOfProgress = 0;
-        while (DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(*dataReader, m_net, nullptr, false, m_parallelRun, inputMatrices, actualMBSize))
+        while (DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(*dataReader, m_net, nullptr, false, m_mpi != nullptr, inputMatrices, actualMBSize, m_mpi))
         {
             size_t actualNumSubminibatches = numSubminibatchesNeeded <= 1 ? 1 : smbDispatcher.GetMinibatchIntoCache(*dataReader, *m_net, inputMatrices, numSubminibatchesNeeded);
             for (size_t ismb = 0; ismb < actualNumSubminibatches; ismb++)
@@ -148,12 +148,12 @@ public:
 
             size_t numSamplesWithLabel = m_net->GetNumSamplesWithLabel(actualMBSize);
             size_t aggregateNumSamplesWithLabel = numSamplesWithLabel;
-            if (m_parallelRun)
+            if (m_mpi != nullptr)
             {
                 if (m_gradHeader == nullptr)
                 {
                     m_gradHeader = DistGradHeader::Create(evalNodes.size());
-                    m_distGradAgg = make_shared<SimpleDistGradAggregator<ElemType>>(g_mpi, false, m_traceLevel);
+                    m_distGradAgg = make_shared<SimpleDistGradAggregator<ElemType>>(m_mpi, false, m_traceLevel);
                 }
 
                 m_gradHeader->numEvalNode = evalNodes.size();
@@ -287,7 +287,7 @@ protected:
     size_t m_numMBsToShowResult;
     size_t m_maxSamplesInRAM;
     size_t m_numSubminiBatches;
-    bool m_parallelRun;
+    std::shared_ptr<MPIWrapper> m_mpi;
 
     shared_ptr<IDistGradAggregator<ElemType>> m_distGradAgg;
     struct DistGradHeader* m_gradHeader;
@@ -107,7 +107,7 @@ public:
         const size_t numIterationsBeforePrintingProgress = 100;
         size_t numItersSinceLastPrintOfProgress = 0;
         size_t actualMBSize;
-        while (DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(dataReader, m_net, nullptr, false, false, inputMatrices, actualMBSize))
+        while (DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(dataReader, m_net, nullptr, false, false, inputMatrices, actualMBSize, nullptr))
         {
             ComputationNetwork::BumpEvalTimeStamp(inputNodes);
 
@@ -237,7 +237,7 @@ public:
         size_t actualMBSize;
         const size_t numIterationsBeforePrintingProgress = 100;
         size_t numItersSinceLastPrintOfProgress = 0;
-        while (DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(dataReader, m_net, nullptr, false, false, inputMatrices, actualMBSize))
+        while (DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(dataReader, m_net, nullptr, false, false, inputMatrices, actualMBSize, nullptr))
         {
             ComputationNetwork::BumpEvalTimeStamp(inputNodes);
 