Bugfix: Disable parallelization for commands other than ones for train, test and cv

This commit is contained in:
Amit Agarwal 2016-03-12 00:04:52 -08:00
Родитель 81a83b7e4b
Коммит 877725d15e
16 изменённых файлов: 210 добавлений и 163 удалений

@ -1 +1 @@
Subproject commit f785679a6bd5cc089b138b3c6bcb68e4b1f345ae Subproject commit 41c1f55b9d5115c4dd051391f38eed8e93fb1860

Просмотреть файл

@ -59,8 +59,6 @@ static void DoEvalBase(const ConfigParameters& config, IDataReader& reader)
size_t numMBsToShowResult = config(L"numMBsToShowResult", "100"); size_t numMBsToShowResult = config(L"numMBsToShowResult", "100");
size_t maxSamplesInRAM = config(L"maxSamplesInRAM", (size_t)SIZE_MAX); size_t maxSamplesInRAM = config(L"maxSamplesInRAM", (size_t)SIZE_MAX);
size_t numSubminiBatches = config(L"numSubminibatches", (size_t)1); size_t numSubminiBatches = config(L"numSubminibatches", (size_t)1);
//TODO: switch to a global parallel setting for both training and evaluation.
bool useParallel = config(L"parallelTrain", false);
ConfigArray evalNodeNames = config(L"evalNodeNames", ""); ConfigArray evalNodeNames = config(L"evalNodeNames", "");
vector<wstring> evalNodeNamesVector; vector<wstring> evalNodeNamesVector;
@ -71,7 +69,7 @@ static void DoEvalBase(const ConfigParameters& config, IDataReader& reader)
auto net = ComputationNetwork::CreateFromFile<ElemType>(deviceId, modelPath); auto net = ComputationNetwork::CreateFromFile<ElemType>(deviceId, modelPath);
SimpleEvaluator<ElemType> eval(net, useParallel, numMBsToShowResult, traceLevel, maxSamplesInRAM, numSubminiBatches); SimpleEvaluator<ElemType> eval(net, MPIWrapper::GetInstance(), numMBsToShowResult, traceLevel, maxSamplesInRAM, numSubminiBatches);
eval.Evaluate(&reader, evalNodeNamesVector, mbSize[0], epochSize); eval.Evaluate(&reader, evalNodeNamesVector, mbSize[0], epochSize);
} }
@ -120,8 +118,6 @@ void DoCrossValidate(const ConfigParameters& config)
size_t numMBsToShowResult = config(L"numMBsToShowResult", "100"); size_t numMBsToShowResult = config(L"numMBsToShowResult", "100");
size_t maxSamplesInRAM = config(L"maxSamplesInRAM", (size_t)SIZE_MAX); size_t maxSamplesInRAM = config(L"maxSamplesInRAM", (size_t)SIZE_MAX);
size_t numSubminiBatches = config(L"numSubminibatches", (size_t)1); size_t numSubminiBatches = config(L"numSubminibatches", (size_t)1);
//TODO: switch to a global parallel setting for both training and evaluation.
bool useParallel = config(L"parallelTrain", false);
ConfigArray evalNodeNames = config(L"evalNodeNames", ""); ConfigArray evalNodeNames = config(L"evalNodeNames", "");
vector<wstring> evalNodeNamesVector; vector<wstring> evalNodeNamesVector;
@ -155,7 +151,7 @@ void DoCrossValidate(const ConfigParameters& config)
cvModels.push_back(cvModelPath); cvModels.push_back(cvModelPath);
auto net = ComputationNetwork::CreateFromFile<ElemType>(deviceId, cvModelPath); auto net = ComputationNetwork::CreateFromFile<ElemType>(deviceId, cvModelPath);
SimpleEvaluator<ElemType> eval(net, useParallel, numMBsToShowResult, traceLevel, maxSamplesInRAM, numSubminiBatches); SimpleEvaluator<ElemType> eval(net, MPIWrapper::GetInstance(), numMBsToShowResult, traceLevel, maxSamplesInRAM, numSubminiBatches);
fprintf(stderr, "model %ls --> \n", cvModelPath.c_str()); fprintf(stderr, "model %ls --> \n", cvModelPath.c_str());
auto evalErrors = eval.Evaluate(&cvDataReader, evalNodeNamesVector, mbSize[0], epochSize); auto evalErrors = eval.Evaluate(&cvDataReader, evalNodeNamesVector, mbSize[0], epochSize);

Просмотреть файл

@ -172,6 +172,7 @@ void DoTrain(const ConfigRecordType& config)
optimizer = make_shared<SGD<ElemType>>(configSGD); optimizer = make_shared<SGD<ElemType>>(configSGD);
} }
optimizer->InitMPI(MPIWrapper::GetInstance());
optimizer->Train(createNetworkFn, deviceId, dataReader.get(), cvDataReader.get(), makeMode); optimizer->Train(createNetworkFn, deviceId, dataReader.get(), cvDataReader.get(), makeMode);
} }
@ -241,6 +242,7 @@ void DoAdapt(const ConfigParameters& config)
SGD<ElemType> sgd(configSGD); SGD<ElemType> sgd(configSGD);
sgd.InitMPI(MPIWrapper::GetInstance());
sgd.Adapt(origModelFileName, refNodeName, dataReader.get(), cvDataReader.get(), deviceId, makeMode); sgd.Adapt(origModelFileName, refNodeName, dataReader.get(), cvDataReader.get(), deviceId, makeMode);
} }

Просмотреть файл

@ -56,9 +56,6 @@
#define let const auto #define let const auto
#endif #endif
// TODO: Get rid of these globals
Microsoft::MSR::CNTK::MPIWrapper* g_mpi = nullptr;
// TODO: Temporary mechanism to enable memory sharing for // TODO: Temporary mechanism to enable memory sharing for
// node output value matrices. This will go away when the // node output value matrices. This will go away when the
// sharing is ready to be enabled by default // sharing is ready to be enabled by default
@ -154,9 +151,13 @@ static void DisableLegacyUsage(const ConfigParameters& TopLevelConfig, const Con
} }
} }
// When running in parallel with MPI, only commands in 'commandstoRunOnAllRanks' should
// be run in parallel across multiple ranks. Others should only run on rank 0
const std::set<std::string> commandstoRunOnAllRanks = { "train", "trainRNN", "adapt", "test", "eval", "cv", "devtest" };
// process the command // process the command
template <typename ElemType> template <typename ElemType>
void DoCommands(const ConfigParameters& config) void DoCommands(const ConfigParameters& config, const shared_ptr<MPIWrapper>& mpi)
{ {
ConfigArray command = config(L"command", "train"); ConfigArray command = config(L"command", "train");
@ -197,7 +198,7 @@ void DoCommands(const ConfigParameters& config)
std::cerr << "CNTKCommandTrainInfo: CNTKNoMoreCommands_Total : " << fullTotalMaxEpochs << endl; std::cerr << "CNTKCommandTrainInfo: CNTKNoMoreCommands_Total : " << fullTotalMaxEpochs << endl;
// set up progress tracing for compute cluster management // set up progress tracing for compute cluster management
if (progressTracing && ((g_mpi == nullptr) || g_mpi->IsMainNode())) if (progressTracing && ((mpi == nullptr) || mpi->IsMainNode()))
{ {
ProgressTracing::SetTracingFlag(); ProgressTracing::SetTracingFlag();
ProgressTracing::TraceTotalNumberOfSteps(fullTotalMaxEpochs); // enable tracing, using this as the total number of epochs ProgressTracing::TraceTotalNumberOfSteps(fullTotalMaxEpochs); // enable tracing, using this as the total number of epochs
@ -212,7 +213,7 @@ void DoCommands(const ConfigParameters& config)
ConfigParameters commandParams(config(command[i])); ConfigParameters commandParams(config(command[i]));
ConfigArray action = commandParams("action", "train"); ConfigArray action = commandParams("action", "train");
if (progressTracing && ((g_mpi == nullptr) || g_mpi->IsMainNode())) if (progressTracing && ((mpi == nullptr) || mpi->IsMainNode()))
{ {
ProgressTracing::SetStepOffset(fullEpochsOffset); // this is the epoch number that SGD will log relative to ProgressTracing::SetStepOffset(fullEpochsOffset); // this is the epoch number that SGD will log relative to
} }
@ -231,6 +232,8 @@ void DoCommands(const ConfigParameters& config)
fprintf(stderr, "#%*s#\n", (int)(strlen(delim) - 2), ""); fprintf(stderr, "#%*s#\n", (int)(strlen(delim) - 2), "");
fprintf(stderr, "%s\n\n", delim); fprintf(stderr, "%s\n\n", delim);
if ((mpi == nullptr) || (commandstoRunOnAllRanks.find(thisAction) != commandstoRunOnAllRanks.end()) || mpi->IsMainNode())
{
if (thisAction == "train" || thisAction == "trainRNN") if (thisAction == "train" || thisAction == "trainRNN")
{ {
std::cerr << "CNTKCommandTrainBegin: " + command[i] << endl; std::cerr << "CNTKCommandTrainBegin: " + command[i] << endl;
@ -294,11 +297,16 @@ void DoCommands(const ConfigParameters& config)
{ {
RuntimeError("unknown action: %s in command set: %s", thisAction.c_str(), command[i].c_str()); RuntimeError("unknown action: %s in command set: %s", thisAction.c_str(), command[i].c_str());
} }
}
fprintf(stderr, "\nAction \"%s\" complete.\n\n", thisAction.c_str()); fprintf(stderr, "\nAction \"%s\" complete.\n\n", thisAction.c_str());
NDLScript<ElemType> ndlScript; NDLScript<ElemType> ndlScript;
ndlScript.ClearGlobal(); // clear global macros between commands ndlScript.ClearGlobal(); // clear global macros between commands
// Synchronize all ranks before proceeding to next action/command
if (mpi != nullptr)
mpi->WaitAll();
} }
} }
} }
@ -459,10 +467,10 @@ int wmainWithBS(int argc, wchar_t* argv[]) // called from wmain which is a wrapp
InvalidArgument("Legacy name 'type' no longer allowed. Use 'precision'."); InvalidArgument("Legacy name 'type' no longer allowed. Use 'precision'.");
// parallel training // parallel training
g_mpi = nullptr; shared_ptr<Microsoft::MSR::CNTK::MPIWrapper> mpi;
bool paralleltrain = config(L"parallelTrain", false); bool paralleltrain = config(L"parallelTrain", false);
if (paralleltrain) if (paralleltrain)
g_mpi = new MPIWrapper(); mpi = MPIWrapper::GetInstance(true /*create*/);
g_shareNodeValueMatrices = config(L"shareNodeValueMatrices", false); g_shareNodeValueMatrices = config(L"shareNodeValueMatrices", false);
@ -476,7 +484,7 @@ int wmainWithBS(int argc, wchar_t* argv[]) // called from wmain which is a wrapp
logpath += L".log"; // TODO: why do we need to append this here? logpath += L".log"; // TODO: why do we need to append this here?
if (paralleltrain) if (paralleltrain)
logpath += msra::strfun::wstrprintf(L"rank%d", (int) g_mpi->CurrentNodeRank()); logpath += msra::strfun::wstrprintf(L"rank%d", (int) mpi->CurrentNodeRank());
RedirectStdErr(logpath); RedirectStdErr(logpath);
fprintf(stderr, "%ls\n", startupMessage.c_str()); fprintf(stderr, "%ls\n", startupMessage.c_str());
@ -495,7 +503,7 @@ int wmainWithBS(int argc, wchar_t* argv[]) // called from wmain which is a wrapp
bool progressTracing = config(L"progressTracing", false); bool progressTracing = config(L"progressTracing", false);
size_t fullTotalMaxEpochs = 1; // BUGBUG: BS does not allow me to read out the max epochs parameters, as that would instantiate and thus execute the objects size_t fullTotalMaxEpochs = 1; // BUGBUG: BS does not allow me to read out the max epochs parameters, as that would instantiate and thus execute the objects
// set up progress tracing for compute cluster management // set up progress tracing for compute cluster management
if (progressTracing && ((g_mpi == nullptr) || g_mpi->IsMainNode())) if (progressTracing && ((mpi == nullptr) || mpi->IsMainNode()))
ProgressTracing::TraceTotalNumberOfSteps(fullTotalMaxEpochs); // enable tracing, using this as the total number of epochs ProgressTracing::TraceTotalNumberOfSteps(fullTotalMaxEpochs); // enable tracing, using this as the total number of epochs
// MAIN LOOP that executes the actions // MAIN LOOP that executes the actions
@ -508,6 +516,8 @@ int wmainWithBS(int argc, wchar_t* argv[]) // called from wmain which is a wrapp
const ScriptableObjects::ConfigArray& actions = actionsVal; const ScriptableObjects::ConfigArray& actions = actionsVal;
for (int i = actions.GetIndexRange().first; i <= actions.GetIndexRange().second; i++) for (int i = actions.GetIndexRange().first; i <= actions.GetIndexRange().second; i++)
{ {
// TODO: When running in parallel with MPI, only commands in 'commandstoRunOnAllRanks' should
// be run in parallel across multiple ranks. Others should only run on rank 0
actions.At(i, [](const wstring&) actions.At(i, [](const wstring&)
{ {
}); // this will evaluate and thus execute the action }); // this will evaluate and thus execute the action
@ -525,7 +535,7 @@ int wmainWithBS(int argc, wchar_t* argv[]) // called from wmain which is a wrapp
} }
fprintf(stderr, "COMPLETED\n"), fflush(stderr); fprintf(stderr, "COMPLETED\n"), fflush(stderr);
delete g_mpi; MPIWrapper::DeleteInstance();
return EXIT_SUCCESS; return EXIT_SUCCESS;
} }
@ -546,12 +556,10 @@ int wmainOldCNTKConfig(int argc, wchar_t* argv[]) // called from wmain which is
ConfigArray command = config(L"command", "train"); ConfigArray command = config(L"command", "train");
// paralleltrain training // paralleltrain training
g_mpi = nullptr; shared_ptr<Microsoft::MSR::CNTK::MPIWrapper> mpi;
bool paralleltrain = config(L"parallelTrain", "false"); bool paralleltrain = config(L"parallelTrain", "false");
if (paralleltrain) if (paralleltrain)
{ mpi = MPIWrapper::GetInstance(true /*create*/);
g_mpi = new MPIWrapper();
}
g_shareNodeValueMatrices = config(L"shareNodeValueMatrices", false); g_shareNodeValueMatrices = config(L"shareNodeValueMatrices", false);
@ -569,7 +577,7 @@ int wmainOldCNTKConfig(int argc, wchar_t* argv[]) // called from wmain which is
if (paralleltrain) if (paralleltrain)
{ {
std::wostringstream oss; std::wostringstream oss;
oss << g_mpi->CurrentNodeRank(); oss << mpi->CurrentNodeRank();
logpath += L"rank" + oss.str(); logpath += L"rank" + oss.str();
} }
RedirectStdErr(logpath); RedirectStdErr(logpath);
@ -616,9 +624,9 @@ int wmainOldCNTKConfig(int argc, wchar_t* argv[]) // called from wmain which is
fprintf(stderr, "\nPrecision = \"%s\"\n", type.c_str()); fprintf(stderr, "\nPrecision = \"%s\"\n", type.c_str());
if (type == "float") if (type == "float")
DoCommands<float>(config); DoCommands<float>(config, mpi);
else if (type == "double") else if (type == "double")
DoCommands<double>(config); DoCommands<double>(config, mpi);
else else
RuntimeError("CNTK: Invalid precision string: \"%s\", must be \"float\" or \"double\"", type.c_str()); RuntimeError("CNTK: Invalid precision string: \"%s\", must be \"float\" or \"double\"", type.c_str());
@ -631,7 +639,7 @@ int wmainOldCNTKConfig(int argc, wchar_t* argv[]) // called from wmain which is
} }
fprintf(stderr, "COMPLETED\n"), fflush(stderr); fprintf(stderr, "COMPLETED\n"), fflush(stderr);
delete g_mpi; MPIWrapper::DeleteInstance();
return EXIT_SUCCESS; return EXIT_SUCCESS;
} }

Просмотреть файл

@ -10,6 +10,7 @@
#include <string> #include <string>
#include <array> #include <array>
#include <vector> #include <vector>
#include <memory>
namespace Microsoft { namespace MSR { namespace CNTK { namespace Microsoft { namespace MSR { namespace CNTK {
@ -49,7 +50,7 @@ static int operator||(int rc, const MpiFail &what)
RuntimeError("%s", what.c_str()); RuntimeError("%s", what.c_str());
} }
class MPIWrapper class MPIWrapper : public std::enable_shared_from_this<MPIWrapper>
{ {
int m_myRank; int m_myRank;
int m_numMPINodes; int m_numMPINodes;
@ -58,6 +59,8 @@ class MPIWrapper
// MPI communicator that reflects the current subset selection // MPI communicator that reflects the current subset selection
MPI_Comm m_currentComm; MPI_Comm m_currentComm;
static std::shared_ptr<MPIWrapper> s_mpi;
// MPI_Init() with delay-loading the msmpi.dll (possibly causing a failure if missing; we want to catch that) // MPI_Init() with delay-loading the msmpi.dll (possibly causing a failure if missing; we want to catch that)
int MPI_Init_DL() int MPI_Init_DL()
{ {
@ -100,7 +103,6 @@ class MPIWrapper
static int s_myRank; static int s_myRank;
static void MPIWorkaroundAtExit() static void MPIWorkaroundAtExit()
{ {
// Note: we can't use g_mpi, since MPI stack is already down at this point
Sleep(s_myRank * 50); Sleep(s_myRank * 50);
} }
@ -155,6 +157,7 @@ public:
MPI_Finalize(); MPI_Finalize();
} }
private:
void Ping(const char *msg) const void Ping(const char *msg) const
{ {
#undef USE2NDCOMM #undef USE2NDCOMM
@ -218,6 +221,28 @@ public:
Ping("requestnodes (after change)"); Ping("requestnodes (after change)");
} }
public:
static std::shared_ptr<MPIWrapper> GetInstance(bool create = false)
{
static bool initialized = false;
if (create)
{
if (initialized)
LogicError("Creating MPIWrapper instance after a GetInstance call has been already made!");
else
s_mpi = std::make_shared<MPIWrapper>();
}
initialized = true;
return s_mpi;
}
static void DeleteInstance()
{
s_mpi = nullptr;
}
MPI_Comm Communicator() const MPI_Comm Communicator() const
{ {
return m_currentComm; return m_currentComm;
@ -312,8 +337,5 @@ public:
MPI_Barrier(m_currentComm) || MpiFail("waitall: MPI_Barrier"); MPI_Barrier(m_currentComm) || MpiFail("waitall: MPI_Barrier");
} }
}; };
}
}
}
extern Microsoft::MSR::CNTK::MPIWrapper *g_mpi; }}}

Просмотреть файл

@ -2,3 +2,4 @@
#include "Include/MPIWrapper.h" #include "Include/MPIWrapper.h"
int Microsoft::MSR::CNTK::MPIWrapper::s_myRank = -1; int Microsoft::MSR::CNTK::MPIWrapper::s_myRank = -1;
std::shared_ptr<Microsoft::MSR::CNTK::MPIWrapper> Microsoft::MSR::CNTK::MPIWrapper::s_mpi = nullptr;

Просмотреть файл

@ -86,18 +86,12 @@ void ComputationNetwork::SaveEdited(const wstring& fileName, const FileOptions f
void ComputationNetwork::Save(const wstring& fileName, const FileOptions fileFormat) const void ComputationNetwork::Save(const wstring& fileName, const FileOptions fileFormat) const
{ {
VerifyIsCompiled("Save"); VerifyIsCompiled("Save");
// In case of parallel training only the main node should we saving the model to prevent
// the parallel training nodes from colliding to write the same file
// TODO: This does not belong here.
if ((g_mpi == nullptr) || g_mpi->IsMainNode())
{
// Saving into temporary file and then renaming it to the requested fileName // Saving into temporary file and then renaming it to the requested fileName
// This is a standard trick to avoid havign corrupted model files if process dies during writing // This is a standard trick to avoid havign corrupted model files if process dies during writing
wstring tmpFileName = fileName + L".tmp"; wstring tmpFileName = fileName + L".tmp";
SaveToFileImpl(tmpFileName, fileFormat); SaveToFileImpl(tmpFileName, fileFormat);
renameOrDie(tmpFileName, fileName); renameOrDie(tmpFileName, fileName);
} }
}
// TODO: how does the file distinguish float vs double nodes? // TODO: how does the file distinguish float vs double nodes?
void ComputationNetwork::SaveToFileImpl(const wstring& fileName, const FileOptions fileFormat) const void ComputationNetwork::SaveToFileImpl(const wstring& fileName, const FileOptions fileFormat) const

Просмотреть файл

@ -17,9 +17,6 @@
#include "BestGpu.h" #include "BestGpu.h"
#include "MPIWrapper.h" #include "MPIWrapper.h"
// TODO: Get rid of this global
Microsoft::MSR::CNTK::MPIWrapper* g_mpi = nullptr;
// TODO: Temporary mechanism to enable memory sharing for // TODO: Temporary mechanism to enable memory sharing for
// node output value matrices. This will go away when the // node output value matrices. This will go away when the
// sharing is ready to be enabled by default // sharing is ready to be enabled by default

Просмотреть файл

@ -30,7 +30,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {
bool useDistributedMBReading, bool useDistributedMBReading,
bool useParallelTrain, bool useParallelTrain,
StreamMinibatchInputs& inputMatrices, StreamMinibatchInputs& inputMatrices,
size_t& actualMBSize) size_t& actualMBSize,
const std::shared_ptr<MPIWrapper>& mpi)
{ {
auto pMBLayout = net->GetMBLayoutPtr(); auto pMBLayout = net->GetMBLayoutPtr();
// Reading consists of a sequence of Reader API calls: // Reading consists of a sequence of Reader API calls:
@ -66,7 +67,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
// decimate if needed. Decimation happens in-place. // decimate if needed. Decimation happens in-place.
if (!useDistributedMBReading && useParallelTrain) if (!useDistributedMBReading && useParallelTrain)
DecimateMinibatchInPlace<ElemType>(inputMatrices, g_mpi->NumNodesInUse(), g_mpi->CurrentNodeRank(), net->GetMBLayoutPtr()); DecimateMinibatchInPlace<ElemType>(inputMatrices, mpi->NumNodesInUse(), mpi->CurrentNodeRank(), net->GetMBLayoutPtr());
// reader will have resized input node's m_value directly. Nodes must be notified to do necessary internal state updates from that. // reader will have resized input node's m_value directly. Nodes must be notified to do necessary internal state updates from that.
// TODO: This is a stopgap. SGD will at some point change from sets of matrices to sets of nodes. Then this will become much simpler. // TODO: This is a stopgap. SGD will at some point change from sets of matrices to sets of nodes. Then this will become much simpler.

Просмотреть файл

@ -9,7 +9,7 @@ template <class ElemType>
class IDistGradAggregator class IDistGradAggregator
{ {
public: public:
IDistGradAggregator(MPIWrapper* mpi) IDistGradAggregator(const std::shared_ptr<MPIWrapper>& mpi)
: m_mpi(mpi) : m_mpi(mpi)
{ {
} }
@ -37,7 +37,7 @@ public:
} }
protected: protected:
MPIWrapper* m_mpi; std::shared_ptr<MPIWrapper> m_mpi;
}; };
#define UsingIDistGradAggregatorMembers \ #define UsingIDistGradAggregatorMembers \

Просмотреть файл

@ -102,7 +102,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
{ {
typedef shared_ptr<ComputationNode<ElemType>> ComputationNodePtr; typedef shared_ptr<ComputationNode<ElemType>> ComputationNodePtr;
public: public:
IMASGD(MPIWrapper* pMPI, size_t perfReportFreq) IMASGD(const std::shared_ptr<MPIWrapper>& pMPI, size_t perfReportFreq)
:m_MAworkerStatus(pMPI->NumNodesInUse(), MAWorkerStatus::NOTSTARTED), :m_MAworkerStatus(pMPI->NumNodesInUse(), MAWorkerStatus::NOTSTARTED),
m_numSyncPerformed(0), m_numSyncPerformed(0),
m_numWorkers(pMPI->NumNodesInUse()), m_numWorkers(pMPI->NumNodesInUse()),
@ -120,7 +120,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
{ {
m_MAworkerStatus.resize(m_numWorkers); m_MAworkerStatus.resize(m_numWorkers);
std::fill(m_MAworkerStatus.begin(), m_MAworkerStatus.end(), MAWorkerStatus::DataProcessing); std::fill(m_MAworkerStatus.begin(), m_MAworkerStatus.end(), MAWorkerStatus::DataProcessing);
g_mpi->WaitAll(); m_pMPI->WaitAll();
m_perfReporter.OnEpochStart(); m_perfReporter.OnEpochStart();
} }
@ -286,8 +286,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
size_t m_numWorkers; size_t m_numWorkers;
size_t m_myRank; size_t m_myRank;
MASGDPerfStats m_perfReporter; MASGDPerfStats m_perfReporter;
MPIWrapper* m_pMPI; // TODO: to use shared_ptr in the future std::shared_ptr<MPIWrapper> m_pMPI;
}; };
@ -300,7 +299,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
using Base::DownCast; using Base::DownCast;
public: public:
BasicModelAveragingSGD(MPIWrapper* pMPI, size_t reportFreq) BasicModelAveragingSGD(const std::shared_ptr<MPIWrapper>& pMPI, size_t reportFreq)
:Base(pMPI, reportFreq) :Base(pMPI, reportFreq)
{} {}
@ -327,7 +326,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
if (nTotalSamples <= 0) if (nTotalSamples <= 0)
{ {
// prepare for overflow // prepare for overflow
factor = 1.0f / g_mpi->NumNodesInUse(); factor = 1.0f / m_pMPI->NumNodesInUse();
totalSamplesProcessed = samplesSinceLastSync * m_pMPI->NumNodesInUse(); totalSamplesProcessed = samplesSinceLastSync * m_pMPI->NumNodesInUse();
// give an estimated one // give an estimated one
} }

Просмотреть файл

@ -261,11 +261,11 @@ void SGD<ElemType>::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
prevLearnRates[i] = -1.0; prevLearnRates[i] = -1.0;
} }
if (m_parallelizationMethod == ParallelizationMethod::DataParallelSGD) if (GetParallelizationMethod() == ParallelizationMethod::DataParallelSGD)
{ {
InitDistGradAgg(evaluationNodes.size(), m_traceLevel); InitDistGradAgg(evaluationNodes.size(), m_traceLevel);
} }
else if (m_parallelizationMethod == ParallelizationMethod::ModelAveragingSGD) else if (GetParallelizationMethod() == ParallelizationMethod::ModelAveragingSGD)
{ {
InitModelAggregationHandler(m_syncStatsTrace); InitModelAggregationHandler(m_syncStatsTrace);
} }
@ -278,11 +278,14 @@ void SGD<ElemType>::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
{ {
// Synchronize all ranks before writing the model to ensure that // Synchronize all ranks before writing the model to ensure that
// everyone is done loading the model // everyone is done loading the model
if (g_mpi != nullptr) if (m_mpi != nullptr)
{ {
g_mpi->WaitAll(); m_mpi->WaitAll();
} }
// In case of parallel training only the main node should we saving the model to prevent
// the parallel training nodes from colliding to write the same file
if ((m_mpi == nullptr) || m_mpi->IsMainNode())
net->Save(GetModelNameForEpoch(int(startEpoch) - 1)); net->Save(GetModelNameForEpoch(int(startEpoch) - 1));
} }
@ -332,9 +335,9 @@ void SGD<ElemType>::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
{ {
// Synchronize all ranks before proceeding to ensure that // Synchronize all ranks before proceeding to ensure that
// rank 0 has finished writing the previous model file // rank 0 has finished writing the previous model file
if (g_mpi != nullptr) if (m_mpi != nullptr)
{ {
g_mpi->WaitAll(); m_mpi->WaitAll();
} }
Timer timer; Timer timer;
@ -384,6 +387,9 @@ void SGD<ElemType>::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
i + 1, learnRatePerSample, m_minLearnRate); i + 1, learnRatePerSample, m_minLearnRate);
if (m_autoLearnRateSearchType != LearningRateSearchAlgorithm::None) if (m_autoLearnRateSearchType != LearningRateSearchAlgorithm::None)
{ {
// In case of parallel training only the main node should we saving the model to prevent
// the parallel training nodes from colliding to write the same file
if ((m_mpi == nullptr) || m_mpi->IsMainNode())
net->Save(m_modelPath); net->Save(m_modelPath);
} }
break; break;
@ -501,7 +507,7 @@ void SGD<ElemType>::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
if (validationSetDataReader != trainSetDataReader && validationSetDataReader != nullptr) if (validationSetDataReader != trainSetDataReader && validationSetDataReader != nullptr)
{ {
SimpleEvaluator<ElemType> evalforvalidation(net, g_mpi != nullptr); SimpleEvaluator<ElemType> evalforvalidation(net, m_mpi);
vector<wstring> cvSetTrainAndEvalNodes; vector<wstring> cvSetTrainAndEvalNodes;
if (criterionNodes.size() > 0) if (criterionNodes.size() > 0)
{ {
@ -535,10 +541,10 @@ void SGD<ElemType>::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
} }
// broadcast epochCriterion to make sure each processor will have the same learning rate schedule // broadcast epochCriterion to make sure each processor will have the same learning rate schedule
if ((m_parallelizationMethod == ParallelizationMethod::ModelAveragingSGD) && (g_mpi->NumNodesInUse() > 1)) if ((GetParallelizationMethod() == ParallelizationMethod::ModelAveragingSGD) && (m_mpi->NumNodesInUse() > 1))
{ {
g_mpi->Bcast(&epochCriterion, 1, g_mpi->MainNodeRank()); m_mpi->Bcast(&epochCriterion, 1, m_mpi->MainNodeRank());
g_mpi->Bcast(&lrControlCriterion, 1, g_mpi->MainNodeRank()); m_mpi->Bcast(&lrControlCriterion, 1, m_mpi->MainNodeRank());
} }
bool loadedPrevModel = false; bool loadedPrevModel = false;
@ -587,6 +593,9 @@ void SGD<ElemType>::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
} }
else else
{ {
// In case of parallel training only the main node should we saving the model to prevent
// the parallel training nodes from colliding to write the same file
if ((m_mpi == nullptr) || m_mpi->IsMainNode())
net->Save(GetModelNameForEpoch(i, true)); net->Save(GetModelNameForEpoch(i, true));
fprintf(stderr, "Finished training and saved final model\n\n"); fprintf(stderr, "Finished training and saved final model\n\n");
@ -634,13 +643,13 @@ void SGD<ElemType>::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
// Synchronize all ranks before proceeding to ensure that // Synchronize all ranks before proceeding to ensure that
// nobody tries reading the checkpoint file at the same time // nobody tries reading the checkpoint file at the same time
// as rank 0 deleting it below // as rank 0 deleting it below
if (g_mpi != nullptr) if (m_mpi != nullptr)
{ {
g_mpi->WaitAll(); m_mpi->WaitAll();
} }
// persist model and check-point info // persist model and check-point info
if ((g_mpi == nullptr) || g_mpi->IsMainNode()) if ((m_mpi == nullptr) || m_mpi->IsMainNode())
{ {
SaveCheckPointInfo(i, totalSamplesSeen, learnRatePerSample, smoothedGradients, prevCriterion, chosenMinibatchSize); SaveCheckPointInfo(i, totalSamplesSeen, learnRatePerSample, smoothedGradients, prevCriterion, chosenMinibatchSize);
auto modelName = GetModelNameForEpoch(i); auto modelName = GetModelNameForEpoch(i);
@ -677,9 +686,9 @@ void SGD<ElemType>::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
// Synchronize all ranks before proceeding to ensure that // Synchronize all ranks before proceeding to ensure that
// rank 0 has finished writing the model file // rank 0 has finished writing the model file
if (g_mpi != nullptr) if (m_mpi != nullptr)
{ {
g_mpi->WaitAll(); m_mpi->WaitAll();
} }
// progress tracing for compute cluster management // progress tracing for compute cluster management
@ -746,9 +755,9 @@ size_t SGD<ElemType>::TrainOneEpoch(ComputationNetworkPtr net,
localEpochCriterion.SetValue(0); localEpochCriterion.SetValue(0);
localEpochEvalErrors.SetValue(0); localEpochEvalErrors.SetValue(0);
bool useGradientAggregation = ((m_parallelizationMethod == ParallelizationMethod::DataParallelSGD) && bool useGradientAggregation = ((GetParallelizationMethod() == ParallelizationMethod::DataParallelSGD) &&
(epochNumber >= m_parallelizationStartEpochNum)); (epochNumber >= m_parallelizationStartEpochNum));
bool useModelAveraging = ((m_parallelizationMethod == ParallelizationMethod::ModelAveragingSGD) && bool useModelAveraging = ((GetParallelizationMethod() == ParallelizationMethod::ModelAveragingSGD) &&
(epochNumber >= m_parallelizationStartEpochNum)); (epochNumber >= m_parallelizationStartEpochNum));
bool useParallelTrain = useGradientAggregation || useModelAveraging; bool useParallelTrain = useGradientAggregation || useModelAveraging;
@ -778,8 +787,8 @@ size_t SGD<ElemType>::TrainOneEpoch(ComputationNetworkPtr net,
trainSetDataReader->SupportsDistributedMBRead(); trainSetDataReader->SupportsDistributedMBRead();
if (useDistributedMBReading) if (useDistributedMBReading)
{ {
trainSetDataReader->StartDistributedMinibatchLoop(tunedMBSize, epochNumber, g_mpi->CurrentNodeRank(), trainSetDataReader->StartDistributedMinibatchLoop(tunedMBSize, epochNumber, m_mpi->CurrentNodeRank(),
g_mpi->NumNodesInUse(), epochSize); m_mpi->NumNodesInUse(), epochSize);
} }
else else
{ {
@ -813,7 +822,7 @@ size_t SGD<ElemType>::TrainOneEpoch(ComputationNetworkPtr net,
if (useGradientAggregation) if (useGradientAggregation)
{ {
fprintf(stderr, ", DataParallelSGD training (MyRank = %d, NumNodes = %d, NumGradientBits = %d)", fprintf(stderr, ", DataParallelSGD training (MyRank = %d, NumNodes = %d, NumGradientBits = %d)",
(int) g_mpi->CurrentNodeRank(), (int) g_mpi->NumNodesInUse(), (int) m_numGradientBits); (int) m_mpi->CurrentNodeRank(), (int) m_mpi->NumNodesInUse(), (int) m_numGradientBits);
if (m_bufferedAsyncGradientAggregation) if (m_bufferedAsyncGradientAggregation)
{ {
fprintf(stderr, ", BufferedAsyncGradientAggregation is ENABLED"); fprintf(stderr, ", BufferedAsyncGradientAggregation is ENABLED");
@ -844,7 +853,7 @@ size_t SGD<ElemType>::TrainOneEpoch(ComputationNetworkPtr net,
// TODO: is it guaranteed that the GPU is already completed at this point, is it safe to overwrite the buffers? // TODO: is it guaranteed that the GPU is already completed at this point, is it safe to overwrite the buffers?
size_t actualMBSize = 0; size_t actualMBSize = 0;
bool wasDataRead = DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(*trainSetDataReader, net, criterionNodes[0], bool wasDataRead = DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(*trainSetDataReader, net, criterionNodes[0],
useDistributedMBReading, useParallelTrain, *inputMatrices, actualMBSize); useDistributedMBReading, useParallelTrain, *inputMatrices, actualMBSize, m_mpi);
if (!wasDataRead && (!useDistributedMBReading || noMoreSamplesToProcess)) // in case of distributed reading, we do a few more loops until all ranks have completed if (!wasDataRead && (!useDistributedMBReading || noMoreSamplesToProcess)) // in case of distributed reading, we do a few more loops until all ranks have completed
break; // end of epoch break; // end of epoch
@ -1194,11 +1203,11 @@ size_t SGD<ElemType>::TrainOneEpoch(ComputationNetworkPtr net,
} }
// in case of model averaging, do one more final aggregation of criteria // in case of model averaging, do one more final aggregation of criteria
if (useModelAveraging && (g_mpi->NumNodesInUse() > 1)) if (useModelAveraging && (m_mpi->NumNodesInUse() > 1))
{ {
// 1. total epoch samples processed by all workers // 1. total epoch samples processed by all workers
size_t totalEpochSamplesOfAllWorkers = totalEpochSamples; size_t totalEpochSamplesOfAllWorkers = totalEpochSamples;
g_mpi->AllReduce(&totalEpochSamplesOfAllWorkers, 1); m_mpi->AllReduce(&totalEpochSamplesOfAllWorkers, 1);
totalSamplesSeen += totalEpochSamplesOfAllWorkers; totalSamplesSeen += totalEpochSamplesOfAllWorkers;
// 2. criterion and EvalErrors // 2. criterion and EvalErrors
@ -1211,8 +1220,8 @@ size_t SGD<ElemType>::TrainOneEpoch(ComputationNetworkPtr net,
epochEvalErrors[i] = localEpochEvalErrors(0, i); epochEvalErrors[i] = localEpochEvalErrors(0, i);
} }
// merge epochCriterion and epochEvalErrors over nodes // merge epochCriterion and epochEvalErrors over nodes
g_mpi->AllReduce(&epochCriterion, 1); m_mpi->AllReduce(&epochCriterion, 1);
g_mpi->AllReduce(epochEvalErrors); m_mpi->AllReduce(epochEvalErrors);
// 3. modify return value // 3. modify return value
totalEpochSamples = totalEpochSamplesOfAllWorkers; totalEpochSamples = totalEpochSamplesOfAllWorkers;
@ -1298,7 +1307,7 @@ bool SGD<ElemType>::PreCompute(ComputationNetworkPtr net,
const size_t numIterationsBeforePrintingProgress = 100; const size_t numIterationsBeforePrintingProgress = 100;
size_t numItersSinceLastPrintOfProgress = 0; size_t numItersSinceLastPrintOfProgress = 0;
size_t actualMBSizeDummy; size_t actualMBSizeDummy;
while (DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(*trainSetDataReader, net, nullptr, false, false, *inputMatrices, actualMBSizeDummy)) while (DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(*trainSetDataReader, net, nullptr, false, false, *inputMatrices, actualMBSizeDummy, m_mpi))
{ {
// TODO: move these into GetMinibatchIntoNetwork() --but those are passed around; necessary? Can't we get them from 'net'? // TODO: move these into GetMinibatchIntoNetwork() --but those are passed around; necessary? Can't we get them from 'net'?
ComputationNetwork::BumpEvalTimeStamp(featureNodes); ComputationNetwork::BumpEvalTimeStamp(featureNodes);
@ -1821,19 +1830,19 @@ int SGD<ElemType>::SGDTrace(FILE* __restrict __stream, const char* __restrict __
template <class ElemType> template <class ElemType>
void SGD<ElemType>::InitDistGradAgg(int numEvalNodes, int traceLevel) void SGD<ElemType>::InitDistGradAgg(int numEvalNodes, int traceLevel)
{ {
if (m_parallelizationMethod == ParallelizationMethod::DataParallelSGD) if (GetParallelizationMethod() == ParallelizationMethod::DataParallelSGD)
{ {
if (m_distGradAgg == nullptr) if (m_distGradAgg == nullptr)
{ {
#ifdef QUANTIZED_GRADIENT_AGGREGATION #ifdef QUANTIZED_GRADIENT_AGGREGATION
m_distGradAgg = new AllReduceDistGradAggregator<ElemType>(g_mpi, m_numGradientBits, m_zeroThresholdFor1Bit, true /*useQuantizationForSelfStripe*/, m_bufferedAsyncGradientAggregation, traceLevel, m_syncStatsTrace); m_distGradAgg = new AllReduceDistGradAggregator<ElemType>(m_mpi, m_numGradientBits, m_zeroThresholdFor1Bit, true /*useQuantizationForSelfStripe*/, m_bufferedAsyncGradientAggregation, traceLevel, m_syncStatsTrace);
#else #else
if (m_numGradientBits != (8 * sizeof(ElemType))) if (m_numGradientBits != (8 * sizeof(ElemType)))
{ {
RuntimeError("Gradient quantization is unsupported in CNTK binaries built without quantized gradient aggregation support!"); RuntimeError("Gradient quantization is unsupported in CNTK binaries built without quantized gradient aggregation support!");
} }
m_distGradAgg = new SimpleDistGradAggregator<ElemType>(g_mpi, m_bufferedAsyncGradientAggregation, m_syncStatsTrace); m_distGradAgg = new SimpleDistGradAggregator<ElemType>(m_mpi, m_bufferedAsyncGradientAggregation, m_syncStatsTrace);
#endif // !QUANTIZED_GRADIENT_AGGREGATION #endif // !QUANTIZED_GRADIENT_AGGREGATION
} }
@ -1847,12 +1856,12 @@ void SGD<ElemType>::InitDistGradAgg(int numEvalNodes, int traceLevel)
template <class ElemType> template <class ElemType>
void SGD<ElemType>::InitModelAggregationHandler(int traceLevel) void SGD<ElemType>::InitModelAggregationHandler(int traceLevel)
{ {
if (m_parallelizationMethod == ParallelizationMethod::ModelAveragingSGD) if (GetParallelizationMethod() == ParallelizationMethod::ModelAveragingSGD)
{ {
#ifndef BLOCKWISE_MODEL_UPDATE_FILTERING #ifndef BLOCKWISE_MODEL_UPDATE_FILTERING
if (!m_pMASGDHelper) if (!m_pMASGDHelper)
{ {
m_pMASGDHelper = make_shared<BasicModelAveragingSGD<ElemType>>(g_mpi, traceLevel); m_pMASGDHelper = make_shared<BasicModelAveragingSGD<ElemType>>(m_mpi, traceLevel);
} }
#else #else
@ -2011,7 +2020,7 @@ void SGD<ElemType>::SaveCheckPointInfo(const size_t epoch, const size_t totalSam
{ {
// In case of parallel training only the main node should we saving the checkpoint to prevent // In case of parallel training only the main node should we saving the checkpoint to prevent
// the parallel training nodes from colliding to write the same file // the parallel training nodes from colliding to write the same file
if ((g_mpi == nullptr) || g_mpi->IsMainNode()) if ((m_mpi == nullptr) || m_mpi->IsMainNode())
{ {
wstring checkPointFileName = GetCheckPointFileNameForEpoch(int(epoch)); wstring checkPointFileName = GetCheckPointFileNameForEpoch(int(epoch));
// Saving into temporary file and then renaming it to the checkPointFileName // Saving into temporary file and then renaming it to the checkPointFileName
@ -2538,7 +2547,7 @@ SGDParams::SGDParams(const ConfigRecordType& configSGD, size_t sizeofElemType)
m_parallelizationStartEpochNum = 0; m_parallelizationStartEpochNum = 0;
m_nFramesBetweenMASync = 40000; // default 40k frames m_nFramesBetweenMASync = 40000; // default 40k frames
if ((g_mpi != nullptr) && configSGD.Exists(L"ParallelTrain")) if (configSGD.Exists(L"ParallelTrain"))
{ {
const ConfigRecordType& configParallelTrain(configSGD(L"ParallelTrain", ConfigRecordType::Record())); const ConfigRecordType& configParallelTrain(configSGD(L"ParallelTrain", ConfigRecordType::Record()));
m_parallelizationMethod = ParseParallelizationMethod(configParallelTrain(L"parallelizationMethod", L"none")); m_parallelizationMethod = ParseParallelizationMethod(configParallelTrain(L"parallelizationMethod", L"none"));

Просмотреть файл

@ -143,6 +143,14 @@ protected:
return pow(m_momentumParam[epoch], 1.0 / FixUpEffectiveMBSize(m_momentumSpecifiedForMBSize[epoch], numParallelSequences)); return pow(m_momentumParam[epoch], 1.0 / FixUpEffectiveMBSize(m_momentumSpecifiedForMBSize[epoch], numParallelSequences));
} }
ParallelizationMethod GetParallelizationMethod() const
{
if (m_mpi == nullptr)
return ParallelizationMethod::None;
return m_parallelizationMethod;
}
// only true when the user specify LearningRatePerMB and the number of parallel utterances in Reader > 1 // only true when the user specify LearningRatePerMB and the number of parallel utterances in Reader > 1
// bool m_needToNormalizeLRByParallUtterance; // TODO: should go away // bool m_needToNormalizeLRByParallUtterance; // TODO: should go away
// bool m_needToNormalizeMomentumByParallUtterance; // bool m_needToNormalizeMomentumByParallUtterance;
@ -228,6 +236,8 @@ protected:
bool m_useAllDataForPreComputedNode; bool m_useAllDataForPreComputedNode;
// Parallel training // Parallel training
std::shared_ptr<MPIWrapper> m_mpi;
ParallelizationMethod m_parallelizationMethod; ParallelizationMethod m_parallelizationMethod;
bool m_enableDistributedMBReading; bool m_enableDistributedMBReading;
int m_parallelizationStartEpochNum; int m_parallelizationStartEpochNum;
@ -303,6 +313,14 @@ public:
{ {
} }
void InitMPI(const std::shared_ptr<MPIWrapper>& mpi)
{
m_mpi = mpi;
if (m_mpi == nullptr)
m_parallelizationMethod = ParallelizationMethod::None;
}
void Train(function<ComputationNetworkPtr(DEVICEID_TYPE)> createNetworkFn, DEVICEID_TYPE deviceId, void Train(function<ComputationNetworkPtr(DEVICEID_TYPE)> createNetworkFn, DEVICEID_TYPE deviceId,
IDataReader* trainSetDataReader, IDataReader* trainSetDataReader,
IDataReader* validationSetDataReader, IDataReader* validationSetDataReader,

Просмотреть файл

@ -15,7 +15,7 @@ class SimpleDistGradAggregator : public IDistGradAggregator<ElemType>
UsingIDistGradAggregatorMembers; UsingIDistGradAggregatorMembers;
public: public:
SimpleDistGradAggregator(MPIWrapper* mpi, bool useAsyncAggregation, int syncStatsTrace) SimpleDistGradAggregator(const std::shared_ptr<MPIWrapper>& mpi, bool useAsyncAggregation, int syncStatsTrace)
: IDistGradAggregator<ElemType>(mpi), m_useAsyncAggregation(useAsyncAggregation), m_currentEpochNumber(-1), m_bufferedGradHeader(nullptr), m_syncStatsTrace(syncStatsTrace), m_iterationCount(0) : IDistGradAggregator<ElemType>(mpi), m_useAsyncAggregation(useAsyncAggregation), m_currentEpochNumber(-1), m_bufferedGradHeader(nullptr), m_syncStatsTrace(syncStatsTrace), m_iterationCount(0)
{ {
} }

Просмотреть файл

@ -31,14 +31,14 @@ template <class ElemType>
class SimpleEvaluator class SimpleEvaluator
{ {
public: public:
SimpleEvaluator(ComputationNetworkPtr net, const bool parallelRun, const size_t numMBsToShowResult = 100, const int traceLevel = 0, const size_t maxSamplesInRAM = SIZE_MAX, SimpleEvaluator(ComputationNetworkPtr net, const std::shared_ptr<MPIWrapper>& mpi, const size_t numMBsToShowResult = 100, const int traceLevel = 0, const size_t maxSamplesInRAM = SIZE_MAX,
const size_t numSubminiBatches = 1) const size_t numSubminiBatches = 1)
: m_net(net), : m_net(net),
m_numMBsToShowResult(numMBsToShowResult), m_numMBsToShowResult(numMBsToShowResult),
m_traceLevel(traceLevel), m_traceLevel(traceLevel),
m_maxSamplesInRAM(maxSamplesInRAM), m_maxSamplesInRAM(maxSamplesInRAM),
m_numSubminiBatches(numSubminiBatches), m_numSubminiBatches(numSubminiBatches),
m_parallelRun(parallelRun), m_mpi(mpi),
m_distGradAgg(nullptr), m_distGradAgg(nullptr),
m_gradHeader(nullptr) m_gradHeader(nullptr)
{ {
@ -123,7 +123,7 @@ public:
const size_t numIterationsBeforePrintingProgress = 100; const size_t numIterationsBeforePrintingProgress = 100;
size_t numItersSinceLastPrintOfProgress = 0; size_t numItersSinceLastPrintOfProgress = 0;
while (DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(*dataReader, m_net, nullptr, false, m_parallelRun, inputMatrices, actualMBSize)) while (DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(*dataReader, m_net, nullptr, false, m_mpi != nullptr, inputMatrices, actualMBSize, m_mpi))
{ {
size_t actualNumSubminibatches = numSubminibatchesNeeded <= 1 ? 1 : smbDispatcher.GetMinibatchIntoCache(*dataReader, *m_net, inputMatrices, numSubminibatchesNeeded); size_t actualNumSubminibatches = numSubminibatchesNeeded <= 1 ? 1 : smbDispatcher.GetMinibatchIntoCache(*dataReader, *m_net, inputMatrices, numSubminibatchesNeeded);
for (size_t ismb = 0; ismb < actualNumSubminibatches; ismb++) for (size_t ismb = 0; ismb < actualNumSubminibatches; ismb++)
@ -148,12 +148,12 @@ public:
size_t numSamplesWithLabel = m_net->GetNumSamplesWithLabel(actualMBSize); size_t numSamplesWithLabel = m_net->GetNumSamplesWithLabel(actualMBSize);
size_t aggregateNumSamplesWithLabel = numSamplesWithLabel; size_t aggregateNumSamplesWithLabel = numSamplesWithLabel;
if (m_parallelRun) if (m_mpi != nullptr)
{ {
if (m_gradHeader == nullptr) if (m_gradHeader == nullptr)
{ {
m_gradHeader = DistGradHeader::Create(evalNodes.size()); m_gradHeader = DistGradHeader::Create(evalNodes.size());
m_distGradAgg = make_shared<SimpleDistGradAggregator<ElemType>>(g_mpi, false, m_traceLevel); m_distGradAgg = make_shared<SimpleDistGradAggregator<ElemType>>(m_mpi, false, m_traceLevel);
} }
m_gradHeader->numEvalNode = evalNodes.size(); m_gradHeader->numEvalNode = evalNodes.size();
@ -287,7 +287,7 @@ protected:
size_t m_numMBsToShowResult; size_t m_numMBsToShowResult;
size_t m_maxSamplesInRAM; size_t m_maxSamplesInRAM;
size_t m_numSubminiBatches; size_t m_numSubminiBatches;
bool m_parallelRun; std::shared_ptr<MPIWrapper> m_mpi;
shared_ptr<IDistGradAggregator<ElemType>> m_distGradAgg; shared_ptr<IDistGradAggregator<ElemType>> m_distGradAgg;
struct DistGradHeader* m_gradHeader; struct DistGradHeader* m_gradHeader;

Просмотреть файл

@ -107,7 +107,7 @@ public:
const size_t numIterationsBeforePrintingProgress = 100; const size_t numIterationsBeforePrintingProgress = 100;
size_t numItersSinceLastPrintOfProgress = 0; size_t numItersSinceLastPrintOfProgress = 0;
size_t actualMBSize; size_t actualMBSize;
while (DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(dataReader, m_net, nullptr, false, false, inputMatrices, actualMBSize)) while (DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(dataReader, m_net, nullptr, false, false, inputMatrices, actualMBSize, nullptr))
{ {
ComputationNetwork::BumpEvalTimeStamp(inputNodes); ComputationNetwork::BumpEvalTimeStamp(inputNodes);
@ -237,7 +237,7 @@ public:
size_t actualMBSize; size_t actualMBSize;
const size_t numIterationsBeforePrintingProgress = 100; const size_t numIterationsBeforePrintingProgress = 100;
size_t numItersSinceLastPrintOfProgress = 0; size_t numItersSinceLastPrintOfProgress = 0;
while (DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(dataReader, m_net, nullptr, false, false, inputMatrices, actualMBSize)) while (DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(dataReader, m_net, nullptr, false, false, inputMatrices, actualMBSize, nullptr))
{ {
ComputationNetwork::BumpEvalTimeStamp(inputNodes); ComputationNetwork::BumpEvalTimeStamp(inputNodes);