//
// Copyright (c) Microsoft Corporation. All rights reserved.
//

#pragma once

#include "Basics.h"
#include "ComputationNetwork.h"
#include "Config.h"
#include "SGD.h"
#include "Matrix.h"
#include "MPIWrapper.h"
#include "TimerUtility.h"

#include <vector>
#include <list>
#include <memory>
#include <string>
#include <algorithm>

namespace Microsoft { namespace MSR { namespace CNTK {

enum class MAWorkerStatus
{
    DataProcessing = 0,
    DataEnd = 1,
    NOTSTARTED = 2
};

class MASGDPerfStats
{
private:
    size_t m_numWorkers;
    size_t m_myRank;
    size_t m_numSyncPerformedInCurrentEpoch;
    size_t m_reportFrequency;
    size_t m_totalSamplesProcessedSinceLastReport;
    size_t m_localSamplesProcessedSinceLastReport;

    double m_accumulatedSecondsOnSyncPointInOneEpoch;
    size_t m_syncPointHitCounterInOneEpoch;

    Timer m_Timer;

public:
    MASGDPerfStats(size_t myRank, size_t numWorkers) :
        m_numWorkers(numWorkers), m_myRank(myRank), m_numSyncPerformedInCurrentEpoch(0),
        m_reportFrequency(1), m_totalSamplesProcessedSinceLastReport(0), m_localSamplesProcessedSinceLastReport(0),
        m_accumulatedSecondsOnSyncPointInOneEpoch(0), m_syncPointHitCounterInOneEpoch(0)
    {
        m_Timer.Start();
    }

    void SetReportFrequency(size_t freq)
    {
        m_reportFrequency = freq;
    }

    void OnEpochStart()
    {
        m_Timer.Restart();
        m_numSyncPerformedInCurrentEpoch = 0;
        m_accumulatedSecondsOnSyncPointInOneEpoch = 0;
        m_syncPointHitCounterInOneEpoch = 0;
    }

    void OnEpochEnd()
    {
        m_Timer.Stop();
    }

    void OnMAPerformed(size_t localSamplesProcessedSinceLastSync, size_t totalSamplesProcessedSinceLastSync, float secondsOnCommunication)
    {
        m_numSyncPerformedInCurrentEpoch++;
        m_totalSamplesProcessedSinceLastReport += totalSamplesProcessedSinceLastSync;
        m_localSamplesProcessedSinceLastReport += localSamplesProcessedSinceLastSync;
        // Reporting condition:
        // 1. if m_reportFrequency == 0, no reporting;
        // 2. if m_reportFrequency > 0, report MA perf stats every m_reportFrequency model aggregations,
        //    and the first 5 perf stats within each epoch are always reported.
        if (m_reportFrequency > 0 &&
            (m_numSyncPerformedInCurrentEpoch % m_reportFrequency == 0 || m_numSyncPerformedInCurrentEpoch <= 5))
        {
            ReportMAPerfStats(m_totalSamplesProcessedSinceLastReport,
                              m_localSamplesProcessedSinceLastReport,
                              secondsOnCommunication);
            m_totalSamplesProcessedSinceLastReport = 0;
            m_localSamplesProcessedSinceLastReport = 0;
        }
    }

    void OnArriveAtSyncPoint(double secondOnSyncPoint, bool printMessage)
    {
        if (printMessage)
        {
            m_accumulatedSecondsOnSyncPointInOneEpoch += secondOnSyncPoint;
            m_syncPointHitCounterInOneEpoch++;
            fprintf(stderr, "\t\t(model aggregation stats): %d-th sync point was hit, introducing a %.2f-second latency this time; "
                            "accumulated time on sync points = %.2f seconds, average latency = %.2f seconds\n",
                    (int)m_syncPointHitCounterInOneEpoch,
                    secondOnSyncPoint,
                    m_accumulatedSecondsOnSyncPointInOneEpoch,
                    m_accumulatedSecondsOnSyncPointInOneEpoch / m_syncPointHitCounterInOneEpoch);
        }
    }
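    // Prints aggregated throughput since the last report. The figures below are in
    // "k samplesPerSecond", i.e. samples / (seconds * 1000): for example (illustrative
    // numbers only), 40960 samples aggregated over 2.0 seconds would be reported as
    // 20.48k samplesPerSecond in total, or 5.12k samplesPerSecond per worker with 4 workers.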
    void ReportMAPerfStats(size_t totalSamplesProcessedSinceLastReport,
                           size_t localSamplesProcessedSinceLastReport,
                           float secondOnCommunication)
    {
        m_Timer.Stop();
        double secondsSinceLastReport = m_Timer.ElapsedSeconds();
        m_Timer.Restart();

        float totalThroughput = secondsSinceLastReport > 0 ? (float)totalSamplesProcessedSinceLastReport / ((float)secondsSinceLastReport * 1000.0f) : 0.0f;
        float throughputPerWorker = totalThroughput / m_numWorkers;

        string prefix = "\t\t(model aggregation stats) %d-th sync: %8.2f seconds since last report (%.2f seconds on comm.); %d samples processed by %d workers (%d by me);\n"
                        "\t\t(model aggregation stats) %d-th sync: totalThroughput = %.2fk samplesPerSecond, throughputPerWorker = %.2fk samplesPerSecond\n";
        fprintf(stderr, prefix.c_str(), (int)m_numSyncPerformedInCurrentEpoch, secondsSinceLastReport, secondOnCommunication,
                (int)totalSamplesProcessedSinceLastReport, (int)m_numWorkers, (int)localSamplesProcessedSinceLastReport,
                (int)m_numSyncPerformedInCurrentEpoch, totalThroughput, throughputPerWorker);
    }
};

// base class for the MA-SGD algorithm family
template <class ElemType>
class IMASGD
{
    typedef shared_ptr<ComputationNode<ElemType>> ComputationNodePtr;

public:
    IMASGD(const MPIWrapperPtr& pMPI, size_t perfReportFreq, DEVICEID_TYPE devId)
        : m_MAworkerStatus(pMPI->NumNodesInUse(), MAWorkerStatus::NOTSTARTED),
          m_numSyncPerformed(0),
          m_numWorkers(pMPI->NumNodesInUse()),
          m_myRank(pMPI->CurrentNodeRank()),
          m_perfReporter(pMPI->CurrentNodeRank(), pMPI->NumNodesInUse()),
          m_pMPI(pMPI),
          m_deviceId(devId)
    {
        m_perfReporter.SetReportFrequency(perfReportFreq);
    }

    virtual ~IMASGD()
    {
    }

    virtual void OnEpochStart(const std::list<ComputationNodeBasePtr>& /*LearnableNodes*/)
    {
        m_MAworkerStatus.resize(m_numWorkers);
        std::fill(m_MAworkerStatus.begin(), m_MAworkerStatus.end(), MAWorkerStatus::DataProcessing);
        m_pMPI->WaitAll();
        m_perfReporter.OnEpochStart();
    }

    virtual void OnEpochEnd(const std::list<ComputationNodeBasePtr>& LearnableNodes,
                            std::list<Matrix<ElemType>>& smoothedGradient,
                            size_t samplesSinceLastSync)
    {
        m_MAworkerStatus[m_myRank] = MAWorkerStatus::DataEnd;
        Timer syncPointTimer;
        syncPointTimer.Start();
        bool read2sync = UpdateWorkerStatus(MAWorkerStatus::DataEnd);
        syncPointTimer.Stop();
        m_perfReporter.OnArriveAtSyncPoint(syncPointTimer.ElapsedSeconds(), true);
        // assert(read2sync);
        size_t totalSamplesProcessed = 0;
        float secondsOnCommunication = 0.0f;
        if (read2sync)
        {
            m_numSyncPerformed++;
            ModelAggregationProcessing(samplesSinceLastSync, LearnableNodes, smoothedGradient, totalSamplesProcessed, secondsOnCommunication);
            m_perfReporter.OnMAPerformed(samplesSinceLastSync, totalSamplesProcessed, secondsOnCommunication);
        }
        m_pMPI->WaitAll();
        m_perfReporter.OnEpochEnd();
    }

    virtual bool OnArrivingAtSyncPoint(
        const std::list<ComputationNodeBasePtr>& LearnableNodes, /* input/output */
        std::list<Matrix<ElemType>>& smoothedGradient,           /* input/output: under some setups, it will be reset to zero */
        size_t samplesSinceLastSync                              /* input: samples processed since last sync on this worker only */
        )
    {
        Timer syncPointTimer;
        syncPointTimer.Start();
        bool read2Sync = UpdateWorkerStatus(MAWorkerStatus::DataProcessing);
        syncPointTimer.Stop();
        m_perfReporter.OnArriveAtSyncPoint(syncPointTimer.ElapsedSeconds(), read2Sync);

        size_t totalSamplesProcessed = 0;
        float secondsOnCommunication = 0.0f;
        if (read2Sync)
        {
            m_numSyncPerformed++;
            ModelAggregationProcessing(samplesSinceLastSync, LearnableNodes, smoothedGradient, totalSamplesProcessed, secondsOnCommunication);
            m_perfReporter.OnMAPerformed(samplesSinceLastSync, totalSamplesProcessed, secondsOnCommunication);
        }
        return read2Sync;
    }

    virtual void ModelAggregationProcessing(
        size_t samplesSinceLastSync,                             /* in     */
        const std::list<ComputationNodeBasePtr>& learnableNodes, /* in/out */
        std::list<Matrix<ElemType>>& smoothedGradient,           /* in/out */
        size_t& totalSamplesProcessed,                           /* out    */
        float& secondsOnCommunication                            /* out    */
        ) = 0;

    virtual void SaveToCheckPoint(File& fstream) {}
    virtual void LoadFromCheckPoint(File& fstream) {}

protected:
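    // Status-exchange helpers used by the sync points above. Each worker keeps a cached
    // status vector (m_MAworkerStatus) and refreshes it by exchanging one MPI_INT per peer,
    // tagged with m_numSyncPerformed, using non-blocking sends and blocking receives.
    // A worker announcing DataEnd always proceeds with aggregation; a worker announcing
    // DataProcessing aggregates only if no peer has reached the end of its data.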
    bool somePeersHaveArrivedAtEnd()
    {
        auto iter = std::find(m_MAworkerStatus.begin(), m_MAworkerStatus.end(), MAWorkerStatus::DataEnd);
        return iter != m_MAworkerStatus.end();
    }

    bool UpdateWorkerStatus(MAWorkerStatus myStatus)
    {
        bool retval = false;
        m_MAworkerStatus[m_myRank] = myStatus;
        if (myStatus == MAWorkerStatus::DataEnd)
        {
            // in this case, we always return true
            vector<MPI_Request> sendRequests(m_numWorkers);
            int sentSignal = (int)MAWorkerStatus::DataEnd;
            // 1. send my status to notify peers
            for (int dest = 0; dest < (int)m_numWorkers; dest++)
            {
                if (dest != (int)m_myRank)
                {
                    MPI_Isend(&sentSignal, 1, MPI_INT, dest, m_numSyncPerformed, m_pMPI->Communicator(), &sendRequests[dest]);
                }
            }
            // 2. recv status from others
            for (int src = 0; src < (int)m_numWorkers; src++)
            {
                if (src != (int)m_myRank && m_MAworkerStatus[src] == MAWorkerStatus::DataProcessing)
                {
                    int recvSignal = 0;
                    MPI_Status status;
                    MPI_Recv(&recvSignal, 1, MPI_INT, src, m_numSyncPerformed, m_pMPI->Communicator(), &status);
                    m_MAworkerStatus[src] = (MAWorkerStatus)recvSignal;
#if 0
                    assert(status.MPI_SOURCE == src);
                    assert(status.MPI_TAG == m_numSyncPerformed);
#endif
                }
            }
            // 3. make sure the sending operations have finished
            for (int dest = 0; dest < (int)m_numWorkers; dest++)
            {
                if (dest != (int)m_myRank)
                {
                    MPI_Wait(&sendRequests[dest], MPI_STATUS_IGNORE);
                }
            }
            retval = true;
        }
        else if (myStatus == MAWorkerStatus::DataProcessing)
        {
            // in this case, we return true if all nodes are ready to sync (i.e., all of them are in the DataProcessing state);
            // otherwise, we return false
            retval = false;
            if (!somePeersHaveArrivedAtEnd())
            {
                int sentSignal = (int)MAWorkerStatus::DataProcessing;
                vector<MPI_Request> sendRequests(m_numWorkers);
                // 1. send my status to peers
                for (int dest = 0; dest < (int)m_numWorkers; dest++)
                {
                    if (dest != (int)m_myRank)
                    {
                        MPI_Isend(&sentSignal, 1, MPI_INT, dest, m_numSyncPerformed, m_pMPI->Communicator(), &sendRequests[dest]);
                    }
                }
                // 2. recv status from others (blocking call)
                for (int src = 0; src < (int)m_numWorkers; src++)
                {
                    if (src != (int)m_myRank)
                    {
                        int recvSignal = 0;
                        MPI_Status status;
                        MPI_Recv(&recvSignal, 1, MPI_INT, src, m_numSyncPerformed, m_pMPI->Communicator(), &status);
#if 0
                        // for debugging purposes, to be removed when mature
                        assert(status.MPI_SOURCE == src);
                        assert(status.MPI_TAG == m_numSyncPerformed);
#endif
                        m_MAworkerStatus[src] = (MAWorkerStatus)recvSignal;
                    }
                }
                // 3. make sure the sending operations have completed
                for (int dest = 0; dest < (int)m_numWorkers; dest++)
                {
                    if (dest != (int)m_myRank)
                    {
                        MPI_Wait(&sendRequests[dest], MPI_STATUS_IGNORE);
                    }
                }
                // 4. check peer status again
                retval = !somePeersHaveArrivedAtEnd();
            }
        }
        else
        {
            LogicError("UpdateWorkerStatus cannot accept a WorkerStatus other than DataProcessing or DataEnd\n");
        }
        return retval;
    }

    // borrow the DownCast function from ComputationNetwork
    ComputationNodePtr DownCast(ComputationNodeBasePtr inode)
    {
        ComputationNodePtr node = dynamic_pointer_cast<ComputationNode<ElemType>>(inode);
        if (!node)
            InvalidArgument("a ComputationNodeBasePtr of mismatching precision was passed");
        return node;
    }

    std::vector<MAWorkerStatus> m_MAworkerStatus;
    int m_numSyncPerformed;
    size_t m_numWorkers;
    size_t m_myRank;
    MASGDPerfStats m_perfReporter;
    MPIWrapperPtr m_pMPI;
    DEVICEID_TYPE m_deviceId;
};
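// The standard model-averaging rule implemented below: if worker k processed n_k samples
// since the last sync and holds parameters W_k, the aggregated model is
//     W = sum_k (n_k / sum_j n_j) * W_k,
// computed by scaling the local copy with its own contribution factor and then summing
// across workers with an AllReduce.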
// Implementation of standard model averaging
template <class ElemType>
class BasicModelAveragingSGD : public IMASGD<ElemType>
{
    typedef IMASGD<ElemType> Base;
    using Base::m_pMPI;
    using Base::DownCast;

public:
    BasicModelAveragingSGD(const MPIWrapperPtr& pMPI, size_t reportFreq, DEVICEID_TYPE devID)
        : Base(pMPI, reportFreq, devID)
    {
        fprintf(stderr, "Parallel training (%d workers) using ModelAveraging\n", (int)m_pMPI->NumNodesInUse());
    }

    void ModelAggregationProcessing(
        size_t samplesSinceLastSync,                             /* in     */
        const std::list<ComputationNodeBasePtr>& learnableNodes, /* in/out */
        std::list<Matrix<ElemType>>& smoothedGradient,           /* in/out */
        size_t& totalSamplesProcessed,                           /* out    */
        float& secondsOnCommunication                            /* out    */
        ) override
        // NOTE: the variable types are determined by the interface in SGD::TrainOneEpoch;
        // even for a const std::list, the objects being pointed to can still be modified
    {
        //----------------------------------------
        // 1. communicate with other nodes to negotiate contribution weights
        //----------------------------------------
        float factor = 0;
        int nTotalSamples = samplesSinceLastSync;
        Timer commTimer;
        secondsOnCommunication = 0.0f;
        commTimer.Start();
        m_pMPI->AllReduce(&nTotalSamples, 1);
        commTimer.Stop();
        secondsOnCommunication += (float)commTimer.ElapsedSeconds();

        if (nTotalSamples <= 0)
        {
            // prepare for overflow
            factor = 1.0f / m_pMPI->NumNodesInUse();
            totalSamplesProcessed = samplesSinceLastSync * m_pMPI->NumNodesInUse(); // give an estimate
        }
        else
        {
            factor = (samplesSinceLastSync + 0.0f) / nTotalSamples;
            totalSamplesProcessed = nTotalSamples;
        }

        //----------------------------------------
        // 2. process each individual node
        //----------------------------------------
        for (auto& pBaseNode : learnableNodes)
        {
            if (!pBaseNode->IsParameterUpdateRequired())
            {
                continue;
            }
            // 2.1 model averaging
            auto pNode = DownCast(pBaseNode);
            // 2.1.1. average model from individual models
            Matrix<ElemType> mat(pNode->Value().DeepClone()); // pNode->Value() returns an lvalue, so a deep copy is invoked here
            // 2.1.2. normalize the weight matrix
            Matrix<ElemType>::Scale(factor, mat);
            // 2.1.3. send the weight matrix over MPI nodes
            unique_ptr<ElemType[]> px(mat.CopyToArray());
            size_t nx = mat.GetNumElements();
            // 2.1.4. in-place sum
            commTimer.Restart();
            m_pMPI->AllReduce(px.get(), nx);
            commTimer.Stop();
            secondsOnCommunication += (float)commTimer.ElapsedSeconds();
            // 2.1.5. set value
            pNode->Value().SetValue(mat.GetNumRows(), mat.GetNumCols(), mat.GetDeviceId(), px.get());
            // 2.1.6. clean-up is handled by the unique_ptr (no explicit delete[] needed)
        }
    }
};

} } }