Avoid code duplication by refactoring WaitAll into a separate function.

unknown 2016-11-11 11:00:32 +08:00
Parent 249989b95f
Commit 8fad703f27
2 changed files with 20 additions and 28 deletions

View file

@@ -425,15 +425,7 @@ void SGD<ElemType>::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
{
// Synchronize all ranks before proceeding to ensure that
// rank 0 has finished writing the previous model file
if (m_mpi != nullptr && GetParallelizationMethod() != ParallelizationMethod::dataParallelASGD)
{
m_mpi->WaitAll();
}
if (m_mpi != nullptr && GetParallelizationMethod() == ParallelizationMethod::dataParallelASGD)
{
m_pASGDHelper->WaitAll();
}
BarrierWorkers();
// (re-)initialize 1-bit SGD
if (GetParallelizationMethod() == ParallelizationMethod::dataParallelSGD &&
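
The synchronization comment at the top of this hunk describes a write-then-read handoff: rank 0 writes the model file, and no rank may proceed until that write is complete. WaitAll is presumed to wrap a collective barrier; a minimal standalone sketch of the same handoff in plain MPI (not CNTK's MPIWrapper; the file name and payload are made up):

#include <mpi.h>
#include <fstream>

int main(int argc, char** argv)
{
    MPI_Init(&argc, &argv);
    int rank = 0;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);

    if (rank == 0)
        std::ofstream("model.bin") << "serialized weights"; // rank 0 writes; the temporary stream closes the file here

    // No rank passes this point until all ranks, rank 0 included,
    // have reached it -- so the write above is finished.
    MPI_Barrier(MPI_COMM_WORLD);

    std::ifstream model("model.bin"); // now safe to read on every rank

    MPI_Finalize();
    return 0;
}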
@@ -597,7 +589,8 @@ void SGD<ElemType>::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
if (validationSetDataReader != trainSetDataReader && validationSetDataReader != nullptr)
{
// TODO(dataASGD) make the evaluator non-distributed when using ASGD.
// TODO(dataASGD) make the evaluator non-distributed when using ASGD, since Multiverso has another background thread using MPI.
// Making the evaluation serial (non-distributed) will slow down training, especially when the validation set is large.
SimpleEvaluator<ElemType> evalforvalidation(net, UsingAsyncGradientAggregation(i + 1) ? nullptr : m_mpi, m_enableDistributedMBReading);
vector<wstring> cvSetTrainAndEvalNodes;
if (criterionNodes.size() > 0)
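
The ternary above passes nullptr as the communicator when async gradient aggregation is active, which is the convention that drops the evaluator to a serial code path. A minimal sketch of that convention, assuming a null wrapper means "run single-node" (DistributedEvaluator and the MPIWrapper stub are made-up stand-ins, not CNTK's SimpleEvaluator or its real MPIWrapper):

#include <memory>
#include <utility>

struct MPIWrapper { /* rank/size/collectives elided */ };
using MPIWrapperPtr = std::shared_ptr<MPIWrapper>;

class DistributedEvaluator
{
    MPIWrapperPtr m_mpi; // nullptr => serial evaluation
public:
    explicit DistributedEvaluator(MPIWrapperPtr mpi) : m_mpi(std::move(mpi)) {}

    void Evaluate()
    {
        if (m_mpi == nullptr)
        {
            // Serial path: this node walks the whole validation set, so no
            // extra MPI traffic competes with Multiverso's background thread.
        }
        else
        {
            // Distributed path: shard validation minibatches across ranks.
        }
    }
};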
@@ -735,16 +728,7 @@ void SGD<ElemType>::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
// Synchronize all ranks before proceeding to ensure that
// nobody tries reading the checkpoint file at the same time
// as rank 0 deleting it below
// TODO[DataASGD]: workers in async mode did not wait for rank 0
if (m_mpi != nullptr && GetParallelizationMethod() != ParallelizationMethod::dataParallelASGD)
{
m_mpi->WaitAll();
}
if (m_mpi != nullptr && GetParallelizationMethod() == ParallelizationMethod::dataParallelASGD)
{
m_pASGDHelper->WaitAll();
}
BarrierWorkers();
// Persist model and check-point info
if ((m_mpi == nullptr) || m_mpi->IsMainNode())
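
The comment in this hunk describes the inverse handoff: every rank must finish reading the old checkpoint before rank 0 deletes it. The same collective barrier covers that direction too; continuing the rank/communicator setup from the MPI sketch above (the checkpoint name is made up; std::remove is from <cstdio>):

// Every rank finishes reading the previous checkpoint, then:
MPI_Barrier(MPI_COMM_WORLD);       // no rank is still mid-read past this point
if (rank == 0)
    std::remove("checkpoint.old"); // safe: all readers are past the barrier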
@@ -813,15 +797,8 @@ void SGD<ElemType>::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
// Synchronize all ranks before proceeding to ensure that
// rank 0 has finished writing the model file
// TODO[DataASGD]: should the other ranks wait in async mode?
if (m_mpi != nullptr && GetParallelizationMethod() != ParallelizationMethod::dataParallelASGD)
{
m_mpi->WaitAll();
}
BarrierWorkers();
if (m_mpi != nullptr && GetParallelizationMethod() == ParallelizationMethod::dataParallelASGD)
{
m_pASGDHelper->WaitAll();
}
// progress tracing for compute cluster management
ProgressTracing::TraceProgressPercentage(m_maxEpochs, 0.0, true);
ProgressTracing::TraceTrainLoss(m_lastFinishedEpochTrainLoss);
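
All four hunks in this file make the same substitution; reduced to its skeleton, each call site goes from the duplicated pair of conditionals to a single call:

// Before: repeated verbatim at every synchronization point
if (m_mpi != nullptr && GetParallelizationMethod() != ParallelizationMethod::dataParallelASGD)
    m_mpi->WaitAll();
if (m_mpi != nullptr && GetParallelizationMethod() == ParallelizationMethod::dataParallelASGD)
    m_pASGDHelper->WaitAll();

// After: one call; the branching lives in a single place (the header diff below)
BarrierWorkers();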

View file

@@ -579,6 +579,7 @@ private:
{
return ((GetParallelizationMethod() == ParallelizationMethod::dataParallelSGD) && (epochNumber >= m_parallelizationStartEpochNum));
}
bool UsingModelAggregation(size_t epochNumber) const
{
return ((GetParallelizationMethod() == ParallelizationMethod::modelAveragingSGD ||
@@ -590,10 +591,24 @@ private:
{
return ((GetParallelizationMethod() == ParallelizationMethod::dataParallelASGD) && (epochNumber >= m_parallelizationStartEpochNum));
}
bool UsingParallelTrain(size_t epochNumber)
{
return UsingGradientAggregation(epochNumber) || UsingModelAggregation(epochNumber) || UsingAsyncGradientAggregation(epochNumber);
}
void BarrierWorkers()
{
if (m_mpi != nullptr && GetParallelizationMethod() != ParallelizationMethod::dataParallelASGD)
{
m_mpi->WaitAll();
}
if (m_mpi != nullptr && GetParallelizationMethod() == ParallelizationMethod::dataParallelASGD)
{
m_pASGDHelper->WaitAll();
}
return;
}
};
}}}
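
With the helper in place, each synchronization point in TrainOrAdaptModel reduces to the pattern the barriers were guarding all along. A hedged usage sketch (PersistModelAndCheckpoint is a hypothetical stand-in for the actual save calls, not a CNTK function):

// End-of-epoch sequence, as the first file's hunks converge on it:
if (m_mpi == nullptr || m_mpi->IsMainNode())
    PersistModelAndCheckpoint(); // hypothetical stand-in: only rank 0 writes
BarrierWorkers();                // all ranks wait until rank 0 has finished
// ...the next epoch may now safely read the new model/checkpoint on any rank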