Avoid code duplication by refactoring WaitAll into a separate function.
Parent: 249989b95f
Commit: 8fad703f27
@@ -425,15 +425,7 @@ void SGD<ElemType>::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
     {
         // Synchronize all ranks before proceeding to ensure that
         // rank 0 has finished writing the previous model file
-        if (m_mpi != nullptr && GetParallelizationMethod() != ParallelizationMethod::dataParallelASGD)
-        {
-            m_mpi->WaitAll();
-        }
-
-        if (m_mpi != nullptr && GetParallelizationMethod() == ParallelizationMethod::dataParallelASGD)
-        {
-            m_pASGDHelper->WaitAll();
-        }
+        BarrierWorkers();

         // (re-)initialize 1-bit SGD
         if (GetParallelizationMethod() == ParallelizationMethod::dataParallelSGD &&
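For readers skimming the diff: the two comment lines above describe a write-then-read handshake between ranks. A minimal standalone sketch of that protocol, under the assumption that m_mpi->WaitAll() amounts to an MPI barrier (WriteModelThenSync and its parameter are illustrative names, not CNTK code):

    // Sketch only; assumes m_mpi->WaitAll() behaves like MPI_Barrier.
    #include <mpi.h>
    #include <fstream>

    void WriteModelThenSync(const char* modelPath)
    {
        int rank = 0;
        MPI_Comm_rank(MPI_COMM_WORLD, &rank);

        if (rank == 0)
        {
            std::ofstream out(modelPath, std::ios::binary);
            out << "serialized model"; // only rank 0 writes the file
        }                              // stream closes as 'out' leaves scope

        MPI_Barrier(MPI_COMM_WORLD);   // no rank proceeds until rank 0 is done

        // Past the barrier, every rank may safely open modelPath for reading.
    }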
@@ -597,7 +589,8 @@ void SGD<ElemType>::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
 
         if (validationSetDataReader != trainSetDataReader && validationSetDataReader != nullptr)
         {
-            // TODO(dataASGD) making evaluator becoming nondistributed one when using ASGD.
+            // TODO(dataASGD): make the evaluator non-distributed when using ASGD, since Multiverso has another background thread using MPI.
+            // Making the evaluation serial (non-distributed) will slow down training, especially when the validation set is large.
             SimpleEvaluator<ElemType> evalforvalidation(net, UsingAsyncGradientAggregation(i + 1) ? nullptr : m_mpi, m_enableDistributedMBReading);
             vector<wstring> cvSetTrainAndEvalNodes;
             if (criterionNodes.size() > 0)
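Aside: the ternary in the SimpleEvaluator line above is the entire mechanism for forcing serial evaluation under ASGD: passing a null communicator disables the distributed path. A simplified sketch of that pattern (Evaluator, MpiComm, and MakeValidationEvaluator are illustrative stand-ins, not CNTK types):

    #include <memory>
    #include <utility>

    struct MpiComm { /* wraps an MPI communicator */ };

    struct Evaluator
    {
        explicit Evaluator(std::shared_ptr<MpiComm> comm) : m_comm(std::move(comm)) {}
        bool IsDistributed() const { return m_comm != nullptr; } // null comm => serial path
    private:
        std::shared_ptr<MpiComm> m_comm;
    };

    Evaluator MakeValidationEvaluator(bool usingAsyncAggregation, std::shared_ptr<MpiComm> mpi)
    {
        // Mirrors: UsingAsyncGradientAggregation(i + 1) ? nullptr : m_mpi
        // Under ASGD, Multiverso's background thread already uses MPI, so the
        // evaluator must not issue MPI calls of its own.
        return Evaluator(usingAsyncAggregation ? nullptr : std::move(mpi));
    }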
@@ -735,16 +728,7 @@ void SGD<ElemType>::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
         // Synchronize all ranks before proceeding to ensure that
         // nobody tries reading the checkpoint file at the same time
         // as rank 0 deleting it below
-        // TODO[DataASGD]: worker in async-mode didn't wait for the rank 0
-        if (m_mpi != nullptr && GetParallelizationMethod() != ParallelizationMethod::dataParallelASGD)
-        {
-            m_mpi->WaitAll();
-        }
-
-        if (m_mpi != nullptr && GetParallelizationMethod() == ParallelizationMethod::dataParallelASGD)
-        {
-            m_pASGDHelper->WaitAll();
-        }
+        BarrierWorkers();

         // Persist model and check-point info
         if ((m_mpi == nullptr) || m_mpi->IsMainNode())
@@ -813,15 +797,8 @@ void SGD<ElemType>::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
         // Synchronize all ranks before proceeding to ensure that
         // rank 0 has finished writing the model file
-        // TODO[DataASGD]: should othet other rank waiting in async-mode
-        if (m_mpi != nullptr && GetParallelizationMethod() != ParallelizationMethod::dataParallelASGD)
-        {
-            m_mpi->WaitAll();
-        }
+        BarrierWorkers();

-        if (m_mpi != nullptr && GetParallelizationMethod() == ParallelizationMethod::dataParallelASGD)
-        {
-            m_pASGDHelper->WaitAll();
-        }
         // progress tracing for compute cluster management
         ProgressTracing::TraceProgressPercentage(m_maxEpochs, 0.0, true);
         ProgressTracing::TraceTrainLoss(m_lastFinishedEpochTrainLoss);
@@ -579,6 +579,7 @@ private:
     {
         return ((GetParallelizationMethod() == ParallelizationMethod::dataParallelSGD) && (epochNumber >= m_parallelizationStartEpochNum));
     }

     bool UsingModelAggregation(size_t epochNumber) const
     {
         return ((GetParallelizationMethod() == ParallelizationMethod::modelAveragingSGD ||
@@ -590,10 +591,24 @@ private:
     {
         return ((GetParallelizationMethod() == ParallelizationMethod::dataParallelASGD) && (epochNumber >= m_parallelizationStartEpochNum));
     }

     bool UsingParallelTrain(size_t epochNumber)
     {
         return UsingGradientAggregation(epochNumber) || UsingModelAggregation(epochNumber) || UsingAsyncGradientAggregation(epochNumber);
     }

+    void BarrierWorkers()
+    {
+        if (m_mpi != nullptr && GetParallelizationMethod() != ParallelizationMethod::dataParallelASGD)
+        {
+            m_mpi->WaitAll();
+        }
+        if (m_mpi != nullptr && GetParallelizationMethod() == ParallelizationMethod::dataParallelASGD)
+        {
+            m_pASGDHelper->WaitAll();
+        }
+        return;
+    }
+
 };

 }}}
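A closing observation on the new helper: its two guards are mutually exclusive, so the same dispatch can be written with a single if/else. An equivalent sketch (not part of the commit), which also makes the single-process early-out explicit:

    // Equivalent formulation of BarrierWorkers(): pick the barrier that
    // matches the active parallelization method.
    void BarrierWorkers()
    {
        if (m_mpi == nullptr)
            return; // single-process run: nothing to synchronize

        if (GetParallelizationMethod() == ParallelizationMethod::dataParallelASGD)
            m_pASGDHelper->WaitAll(); // ASGD ranks wait via the ASGD helper
        else
            m_mpi->WaitAll();         // all other modes: plain MPI barrier
    }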