Merge branch 'qiwye/asgd-dev' into qiwye/asgd-exp
Conflicts: Source/Multiverso
This commit is contained in:
Коммит
42b3b7a214
|
@ -74,8 +74,8 @@ class MPIWrapper
|
||||||
int argc = 0;
|
int argc = 0;
|
||||||
char **argv = NULL;
|
char **argv = NULL;
|
||||||
// TODO the MPI_THREAD_MULTIPLE support is needed by project Multiverso.
|
// TODO the MPI_THREAD_MULTIPLE support is needed by project Multiverso.
|
||||||
// please make sure using the MSMPIv7 (or openmpi-1.8) and above.
|
// please make sure using the MSMPIv6 (or openmpi-1.8) and above.
|
||||||
int requiredThreadLevelSupport = MPI_THREAD_MULTIPLE;
|
int requiredThreadLevelSupport = MPI_THREAD_SERIALIZED;
|
||||||
int provided;
|
int provided;
|
||||||
int ret = MPI_Init_thread(&argc, &argv, requiredThreadLevelSupport, &provided);
|
int ret = MPI_Init_thread(&argc, &argv, requiredThreadLevelSupport, &provided);
|
||||||
if (provided != requiredThreadLevelSupport)
|
if (provided != requiredThreadLevelSupport)
|
||||||
|
|
|
@ -51,7 +51,7 @@ namespace Microsoft {
|
||||||
public:
|
public:
|
||||||
MultiversoWrapper(const std::list<ComputationNodeBasePtr> & learnableNodes,
|
MultiversoWrapper(const std::list<ComputationNodeBasePtr> & learnableNodes,
|
||||||
int MPINodeNum,
|
int MPINodeNum,
|
||||||
bool isPipeline = true,
|
bool isAsyncBuffered = true,
|
||||||
AdjustLearningRateatBeginning adjusttype = AdjustLearningRateatBeginning::None,
|
AdjustLearningRateatBeginning adjusttype = AdjustLearningRateatBeginning::None,
|
||||||
double adjustcoef = 0.2,
|
double adjustcoef = 0.2,
|
||||||
size_t adjustnbmb = 600)
|
size_t adjustnbmb = 600)
|
||||||
|
@ -64,29 +64,25 @@ namespace Microsoft {
|
||||||
m_totalClientNumber = MPINodeNum;
|
m_totalClientNumber = MPINodeNum;
|
||||||
|
|
||||||
//Pipeline releated variables
|
//Pipeline releated variables
|
||||||
m_isPipelined = isPipeline;
|
m_isUseAsyncBuffered = isAsyncBuffered;
|
||||||
m_localCacheNumber = m_isPipelined ? 2 : 1;
|
m_localCacheNumber = m_isUseAsyncBuffered ? 2 : 1;
|
||||||
m_cacheSwapIndex = new int[m_localCacheNumber];
|
m_cacheSwapIndex = new int[m_localCacheNumber];
|
||||||
|
|
||||||
//CPU double buffer
|
//CPU asynchronous buffer
|
||||||
m_cpuAsyncBuffer = new ElemType*[m_localCacheNumber];
|
m_cpuAsyncBuffer = new ElemType*[m_localCacheNumber];
|
||||||
|
|
||||||
#ifndef CPUONLY
|
#ifndef CPUONLY
|
||||||
//GPU double buffer
|
//GPU asynchronous buffer
|
||||||
m_gpuAsyncBuffer = new Matrix<ElemType>**[m_localCacheNumber];
|
m_gpuAsyncBuffer = new Matrix<ElemType>**[m_localCacheNumber];
|
||||||
|
|
||||||
//Communication Stream
|
//creat an communication stream for the data tranfer between GPU and CPU
|
||||||
CudaErrorCheck(cudaStreamCreate(&_commStream));
|
CudaErrorCheck(cudaStreamCreate(&_commStream));
|
||||||
#endif
|
#endif
|
||||||
|
m_bufferInUse = 0;
|
||||||
m_cacheIndex = 0;
|
|
||||||
for (int i = 0; i < m_localCacheNumber; i++)
|
for (int i = 0; i < m_localCacheNumber; i++)
|
||||||
m_cacheSwapIndex[i] = (i + 1) % m_localCacheNumber;
|
m_cacheSwapIndex[i] = (i + 1) % m_localCacheNumber;
|
||||||
|
|
||||||
m_prefetchThread = new thread();
|
m_prefetchThread = nullptr;
|
||||||
|
|
||||||
m_modelSizeOfEachServer = new size_t[m_totalClientNumber];
|
|
||||||
m_indexOfEachServer = new size_t[m_totalClientNumber];
|
|
||||||
MultiversoInit(learnableNodes);
|
MultiversoInit(learnableNodes);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -95,10 +91,10 @@ namespace Microsoft {
|
||||||
fprintf(stderr, "~MultiversoWrapper\n");
|
fprintf(stderr, "~MultiversoWrapper\n");
|
||||||
fflush(stderr);
|
fflush(stderr);
|
||||||
|
|
||||||
if (m_isPipelined && m_prefetchThread != nullptr && m_prefetchThread->joinable())
|
if (m_isUseAsyncBuffered && m_prefetchThread != nullptr && m_prefetchThread->joinable())
|
||||||
m_prefetchThread->join();
|
m_prefetchThread->join();
|
||||||
|
|
||||||
delete m_cacheSwapIndex, m_deltaArray, m_modelSizeOfEachServer, m_indexOfEachServer;
|
delete m_cacheSwapIndex, m_deltaArray;
|
||||||
|
|
||||||
for (size_t i = 0; i < m_localCacheNumber; i++)
|
for (size_t i = 0; i < m_localCacheNumber; i++)
|
||||||
{
|
{
|
||||||
|
@ -115,13 +111,12 @@ namespace Microsoft {
|
||||||
multiverso::MultiversoShutDown(false);
|
multiverso::MultiversoShutDown(false);
|
||||||
}
|
}
|
||||||
|
|
||||||
// This function will upload parameters into Multiverso
|
// upoload preCompute model to the parameter servers
|
||||||
void InitModel(const std::list<ComputationNodeBasePtr> & learnableNodes)
|
void InitModel(const std::list<ComputationNodeBasePtr> & learnableNodes)
|
||||||
{
|
{
|
||||||
float factor = (float) 1.0 / m_totalClientNumber;
|
float factor = (float) 1.0 / m_totalClientNumber;
|
||||||
|
|
||||||
//weights
|
int i = 0; // indicate the index of learnable nodes
|
||||||
int i = 0;
|
|
||||||
for (auto nodeIter = learnableNodes.begin(); nodeIter != learnableNodes.end(); nodeIter++, i++)
|
for (auto nodeIter = learnableNodes.begin(); nodeIter != learnableNodes.end(); nodeIter++, i++)
|
||||||
{
|
{
|
||||||
ComputationNodePtr node = dynamic_pointer_cast<ComputationNode<ElemType>>(*nodeIter);
|
ComputationNodePtr node = dynamic_pointer_cast<ComputationNode<ElemType>>(*nodeIter);
|
||||||
|
@ -159,10 +154,10 @@ namespace Microsoft {
|
||||||
Timer timer;
|
Timer timer;
|
||||||
WaitAsyncBuffer();
|
WaitAsyncBuffer();
|
||||||
|
|
||||||
m_cacheIndex = m_cacheSwapIndex[m_cacheIndex];
|
m_bufferInUse = m_cacheSwapIndex[m_bufferInUse];
|
||||||
|
|
||||||
int i = 0;
|
int i = 0; // indicate the index of learnable nodes
|
||||||
if (m_isPipelined)
|
if (m_isUseAsyncBuffered)
|
||||||
{
|
{
|
||||||
|
|
||||||
for (auto nodeIter = learnableNodes.begin(); nodeIter != learnableNodes.end(); nodeIter++, i++)
|
for (auto nodeIter = learnableNodes.begin(); nodeIter != learnableNodes.end(); nodeIter++, i++)
|
||||||
|
@ -171,22 +166,22 @@ namespace Microsoft {
|
||||||
Microsoft::MSR::CNTK::Matrix<ElemType> &mat = node->Value();
|
Microsoft::MSR::CNTK::Matrix<ElemType> &mat = node->Value();
|
||||||
#ifndef CPUONLY
|
#ifndef CPUONLY
|
||||||
//CNTK model -> GPU buffer
|
//CNTK model -> GPU buffer
|
||||||
CudaErrorCheck(cudaMemcpy(m_gpuAsyncBuffer[m_cacheIndex][i]->BufferPointer(),
|
CudaErrorCheck(cudaMemcpy(m_gpuAsyncBuffer[m_bufferInUse][i]->BufferPointer(),
|
||||||
mat.BufferPointer(),
|
mat.BufferPointer(),
|
||||||
mat.GetNumElements() * sizeof(ElemType),
|
mat.GetNumElements() * sizeof(ElemType),
|
||||||
cudaMemcpyDeviceToDevice));
|
cudaMemcpyDeviceToDevice));
|
||||||
|
|
||||||
//GPU buffer -> CNTK model
|
//GPU buffer -> CNTK model
|
||||||
CudaErrorCheck(cudaMemcpy(mat.BufferPointer(),
|
CudaErrorCheck(cudaMemcpy(mat.BufferPointer(),
|
||||||
m_gpuAsyncBuffer[m_cacheSwapIndex[m_cacheIndex]][i]->BufferPointer(),
|
m_gpuAsyncBuffer[m_cacheSwapIndex[m_bufferInUse]][i]->BufferPointer(),
|
||||||
mat.GetNumElements() * sizeof(ElemType),
|
mat.GetNumElements() * sizeof(ElemType),
|
||||||
cudaMemcpyDeviceToDevice));
|
cudaMemcpyDeviceToDevice));
|
||||||
#else
|
#else
|
||||||
ElemType * px = m_cpuAsyncBuffer[m_cacheIndex] + m_tableIndex[i];
|
ElemType * px = m_cpuAsyncBuffer[m_bufferInUse] + m_tableIndex[i];
|
||||||
|
|
||||||
mat.CopyToArray(px, m_tableLength[i]);
|
mat.CopyToArray(px, m_tableLength[i]);
|
||||||
|
|
||||||
ElemType * py = m_cpuAsyncBuffer[m_cacheSwapIndex[m_cacheIndex]] + m_tableIndex[i];
|
ElemType * py = m_cpuAsyncBuffer[m_cacheSwapIndex[m_bufferInUse]] + m_tableIndex[i];
|
||||||
|
|
||||||
mat.SetValue(mat.GetNumRows(), mat.GetNumCols(), mat.GetDeviceId(), py);
|
mat.SetValue(mat.GetNumRows(), mat.GetNumCols(), mat.GetDeviceId(), py);
|
||||||
|
|
||||||
|
@ -197,7 +192,7 @@ namespace Microsoft {
|
||||||
#ifndef CPUONLY
|
#ifndef CPUONLY
|
||||||
m_prefetchThread = new thread([&](){
|
m_prefetchThread = new thread([&](){
|
||||||
float factor = DecayCoefficient();
|
float factor = DecayCoefficient();
|
||||||
int t_cacheIdx = m_cacheIndex;
|
int t_cacheIdx = m_bufferInUse;
|
||||||
int deviceId = m_gpuAsyncBuffer[t_cacheIdx][0]->GetDeviceId();
|
int deviceId = m_gpuAsyncBuffer[t_cacheIdx][0]->GetDeviceId();
|
||||||
|
|
||||||
CudaErrorCheck(cudaSetDevice(deviceId));
|
CudaErrorCheck(cudaSetDevice(deviceId));
|
||||||
|
@ -213,10 +208,10 @@ namespace Microsoft {
|
||||||
_commStream));
|
_commStream));
|
||||||
}
|
}
|
||||||
|
|
||||||
//Sync for copy
|
// waiting copy from GPU to CPU finished
|
||||||
CudaErrorCheck(cudaStreamSynchronize(_commStream));
|
CudaErrorCheck(cudaStreamSynchronize(_commStream));
|
||||||
|
|
||||||
//Calculate delta
|
// calculate delta
|
||||||
std::transform(m_deltaArray, m_deltaArray + m_totalModelSize, m_cpuAsyncBuffer[t_cacheIdx], m_deltaArray, std::minus<ElemType>());
|
std::transform(m_deltaArray, m_deltaArray + m_totalModelSize, m_cpuAsyncBuffer[t_cacheIdx], m_deltaArray, std::minus<ElemType>());
|
||||||
|
|
||||||
// lr decay
|
// lr decay
|
||||||
|
@ -224,9 +219,8 @@ namespace Microsoft {
|
||||||
|
|
||||||
m_sharedArray->Add(m_deltaArray, m_totalModelSize);
|
m_sharedArray->Add(m_deltaArray, m_totalModelSize);
|
||||||
m_sharedArray->Get(m_cpuAsyncBuffer[t_cacheIdx], m_totalModelSize);
|
m_sharedArray->Get(m_cpuAsyncBuffer[t_cacheIdx], m_totalModelSize);
|
||||||
//memcpy(m_cpuAsyncBuffer[t_cacheIdx], m_sharedArray->raw().data(), m_totalModelSize);
|
|
||||||
|
|
||||||
//CPU buffer -> GPU buffer
|
// copy parameters from CPU buffer to GPU buffer
|
||||||
for (int widx = 0; widx < m_tableCount; widx++)
|
for (int widx = 0; widx < m_tableCount; widx++)
|
||||||
{
|
{
|
||||||
ElemType * py = m_cpuAsyncBuffer[t_cacheIdx] + m_tableIndex[widx];
|
ElemType * py = m_cpuAsyncBuffer[t_cacheIdx] + m_tableIndex[widx];
|
||||||
|
@ -243,13 +237,13 @@ namespace Microsoft {
|
||||||
});
|
});
|
||||||
#else
|
#else
|
||||||
m_prefetchThread = new thread([&](){
|
m_prefetchThread = new thread([&](){
|
||||||
float factor = getUpdateCoefficient();
|
float factor = DecayCoefficient();
|
||||||
int table_id = 0, t_cacheIdx = m_cacheIndex;
|
int t_cacheIdx = m_bufferInUse;
|
||||||
|
|
||||||
transform(m_deltaArray, m_deltaArray + m_totalModelSize, m_cpuAsyncBuffer[t_cacheIdx], m_deltaArray, std::minus<ElemType>());
|
std::transform(m_deltaArray, m_deltaArray + m_totalModelSize, m_cpuAsyncBuffer[t_cacheIdx], m_deltaArray, std::minus<ElemType>());
|
||||||
|
std::transform(m_deltaArray, m_deltaArray + m_totalModelSize, m_deltaArray, std::bind1st(std::multiplies<ElemType>(), factor));
|
||||||
m_sharedArray->Add(m_deltaArray, m_totalModelSize);
|
m_sharedArray->Add(m_deltaArray, m_totalModelSize);
|
||||||
m_sharedArray->Get(m_cpuAsyncBuffer[t_cacheIdx], m_totalModelSize);
|
m_sharedArray->Get(m_cpuAsyncBuffer[t_cacheIdx], m_totalModelSize);
|
||||||
//memcpy(m_cpuAsyncBuffer[t_cacheIdx], m_sharedArray->raw().data(), m_totalModelSize);
|
|
||||||
|
|
||||||
});
|
});
|
||||||
#endif
|
#endif
|
||||||
|
@ -274,7 +268,6 @@ namespace Microsoft {
|
||||||
|
|
||||||
m_sharedArray->Add(m_deltaArray, m_totalModelSize);
|
m_sharedArray->Add(m_deltaArray, m_totalModelSize);
|
||||||
m_sharedArray->Get(m_cpuAsyncBuffer[0], m_totalModelSize);
|
m_sharedArray->Get(m_cpuAsyncBuffer[0], m_totalModelSize);
|
||||||
//memcpy(m_cpuAsyncBuffer[0], m_sharedArray->raw().data(), m_totalModelSize);
|
|
||||||
|
|
||||||
i = 0;
|
i = 0;
|
||||||
for (auto nodeIter = learnableNodes.begin(); nodeIter != learnableNodes.end(); nodeIter++, i++)
|
for (auto nodeIter = learnableNodes.begin(); nodeIter != learnableNodes.end(); nodeIter++, i++)
|
||||||
|
@ -283,7 +276,6 @@ namespace Microsoft {
|
||||||
Microsoft::MSR::CNTK::Matrix<ElemType> &mat = node->Value();
|
Microsoft::MSR::CNTK::Matrix<ElemType> &mat = node->Value();
|
||||||
|
|
||||||
ElemType * px = m_cpuAsyncBuffer[0] + m_tableIndex[i];
|
ElemType * px = m_cpuAsyncBuffer[0] + m_tableIndex[i];
|
||||||
|
|
||||||
mat.SetValue(mat.GetNumRows(), mat.GetNumCols(), mat.GetDeviceId(), px);
|
mat.SetValue(mat.GetNumRows(), mat.GetNumCols(), mat.GetDeviceId(), px);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -307,7 +299,11 @@ namespace Microsoft {
|
||||||
void WaitAsyncBuffer()
|
void WaitAsyncBuffer()
|
||||||
{
|
{
|
||||||
if (m_prefetchThread != nullptr && m_prefetchThread->joinable())
|
if (m_prefetchThread != nullptr && m_prefetchThread->joinable())
|
||||||
|
{
|
||||||
m_prefetchThread->join();
|
m_prefetchThread->join();
|
||||||
|
delete m_prefetchThread;
|
||||||
|
m_prefetchThread = nullptr;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
private:
|
private:
|
||||||
void MultiversoInit(const std::list<ComputationNodeBasePtr> & learnableNodes)
|
void MultiversoInit(const std::list<ComputationNodeBasePtr> & learnableNodes)
|
||||||
|
@ -315,8 +311,8 @@ namespace Microsoft {
|
||||||
assert(!m_isInitialized);
|
assert(!m_isInitialized);
|
||||||
m_isInitialized = true;
|
m_isInitialized = true;
|
||||||
|
|
||||||
multiverso::MultiversoInit();
|
|
||||||
//multiverso::Log::ResetLogLevel(multiverso::LogLevel::Debug);
|
//multiverso::Log::ResetLogLevel(multiverso::LogLevel::Debug);
|
||||||
|
multiverso::MultiversoInit();
|
||||||
|
|
||||||
//weights
|
//weights
|
||||||
for (auto nodeIter = learnableNodes.begin(); nodeIter != learnableNodes.end(); nodeIter++)
|
for (auto nodeIter = learnableNodes.begin(); nodeIter != learnableNodes.end(); nodeIter++)
|
||||||
|
@ -330,22 +326,15 @@ namespace Microsoft {
|
||||||
|
|
||||||
m_tableCount = m_tableLength.size();
|
m_tableCount = m_tableLength.size();
|
||||||
|
|
||||||
//init cache space.
|
// cacluate total of learnable node's size
|
||||||
m_totalModelSize = accumulate(m_tableLength.begin(), m_tableLength.end(), 0);
|
m_totalModelSize = accumulate(m_tableLength.begin(), m_tableLength.end(), 0);
|
||||||
size_t idx = 0;
|
|
||||||
|
|
||||||
//for (int i = 0; i < m_totalClientNumber; i++)
|
|
||||||
//{
|
|
||||||
// m_indexOfEachServer[i] = idx;
|
|
||||||
// m_modelSizeOfEachServer[i] = i < m_totalModelSize % m_totalClientNumber ? m_totalModelSize / m_totalClientNumber + 1 : m_totalModelSize / m_totalClientNumber;
|
|
||||||
// idx += m_modelSizeOfEachServer[i];
|
|
||||||
//}
|
|
||||||
|
|
||||||
m_sharedArray = new multiverso::ArrayWorker<ElemType>(m_totalModelSize);
|
m_sharedArray = new multiverso::ArrayWorker<ElemType>(m_totalModelSize);
|
||||||
m_serverArray = new multiverso::ArrayServer<ElemType>(m_totalModelSize);
|
m_serverArray = new multiverso::ArrayServer<ElemType>(m_totalModelSize);
|
||||||
|
|
||||||
multiverso::MultiversoBarrier();
|
multiverso::MultiversoBarrier();
|
||||||
idx = 0;
|
|
||||||
|
size_t idx = 0;
|
||||||
for (size_t len : m_tableLength)
|
for (size_t len : m_tableLength)
|
||||||
{
|
{
|
||||||
m_tableIndex.push_back(idx);
|
m_tableIndex.push_back(idx);
|
||||||
|
@ -353,18 +342,17 @@ namespace Microsoft {
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifndef CPUONLY
|
#ifndef CPUONLY
|
||||||
//pinned memory
|
|
||||||
for (int i = 0; i < m_localCacheNumber; ++i)
|
|
||||||
CudaErrorCheck(cudaMallocHost((void **)&m_cpuAsyncBuffer[i], sizeof(ElemType) * (m_totalModelSize + 1), cudaHostAllocPortable));
|
|
||||||
|
|
||||||
CudaErrorCheck(cudaMallocHost((void **)&m_deltaArray, sizeof(ElemType) * (m_totalModelSize + 1), cudaHostAllocPortable));
|
|
||||||
|
|
||||||
//GPU memory cache
|
|
||||||
for (int i = 0; i < m_localCacheNumber; i++)
|
for (int i = 0; i < m_localCacheNumber; i++)
|
||||||
m_gpuAsyncBuffer[i] = new Matrix<ElemType>*[m_tableCount];
|
m_gpuAsyncBuffer[i] = new Matrix<ElemType>*[m_tableCount];
|
||||||
|
|
||||||
|
//create pinned memory
|
||||||
|
for (int i = 0; i < m_localCacheNumber; ++i)
|
||||||
|
CudaErrorCheck(cudaMallocHost((void **)&m_cpuAsyncBuffer[i], sizeof(ElemType) * (m_totalModelSize), cudaHostAllocPortable));
|
||||||
|
|
||||||
|
CudaErrorCheck(cudaMallocHost((void **)&m_deltaArray, sizeof(ElemType) * (m_totalModelSize), cudaHostAllocPortable));
|
||||||
#else
|
#else
|
||||||
for (int i = 0; i < m_localCacheNumber; i++)
|
for (int i = 0; i < m_localCacheNumber; i++)
|
||||||
m_cpuAsyncBuffer[i] = new ElemType[m_totalModelSize + 1];
|
m_cpuAsyncBuffer[i] = new ElemType[m_totalModelSize];
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -393,10 +381,10 @@ namespace Microsoft {
|
||||||
|
|
||||||
int m_totalClientNumber;
|
int m_totalClientNumber;
|
||||||
|
|
||||||
bool m_isPipelined;
|
bool m_isUseAsyncBuffered;
|
||||||
int m_localCacheNumber;
|
int m_localCacheNumber;
|
||||||
int * m_cacheSwapIndex;
|
int * m_cacheSwapIndex;
|
||||||
int m_cacheIndex;
|
int m_bufferInUse;
|
||||||
|
|
||||||
size_t m_modelSyncCount;
|
size_t m_modelSyncCount;
|
||||||
|
|
||||||
|
@ -410,10 +398,6 @@ namespace Microsoft {
|
||||||
ElemType * m_deltaArray;
|
ElemType * m_deltaArray;
|
||||||
ElemType ** m_cpuAsyncBuffer;
|
ElemType ** m_cpuAsyncBuffer;
|
||||||
|
|
||||||
// TODO deprecated this unused variables
|
|
||||||
size_t * m_modelSizeOfEachServer;
|
|
||||||
size_t * m_indexOfEachServer;
|
|
||||||
|
|
||||||
//GPU double buffer
|
//GPU double buffer
|
||||||
Matrix<ElemType> *** m_gpuAsyncBuffer;
|
Matrix<ElemType> *** m_gpuAsyncBuffer;
|
||||||
int m_tableCount;
|
int m_tableCount;
|
||||||
|
|
|
@ -319,19 +319,20 @@ void SGD<ElemType>::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
|
||||||
m_seqGammarCalcAMF, m_seqGammarCalcLMF, m_seqGammarCalcWP, m_seqGammarCalcbMMIFactor, m_seqGammarCalcUsesMBR);
|
m_seqGammarCalcAMF, m_seqGammarCalcLMF, m_seqGammarCalcWP, m_seqGammarCalcbMMIFactor, m_seqGammarCalcUsesMBR);
|
||||||
}
|
}
|
||||||
|
|
||||||
//Multiverso Warpper for ASGD logic init
|
//Multiverso Warpper for ASGD logic init
|
||||||
if (m_parallelizationMethod == ParallelizationMethod::DataParallelASGD)
|
if (m_parallelizationMethod == ParallelizationMethod::DataParallelASGD)
|
||||||
{
|
{
|
||||||
m_multiverso = new MultiversoWrapper<ElemType>(learnableNodes,
|
g_mpi->WaitAll();
|
||||||
g_mpi->NumNodesInUse(),
|
m_multiverso = new MultiversoWrapper<ElemType>(learnableNodes,
|
||||||
m_isPipeline,
|
g_mpi->NumNodesInUse(),
|
||||||
m_adjustlearningrateatbeginning,
|
m_isPipeline,
|
||||||
m_adjustcoefficient,
|
m_adjustlearningrateatbeginning,
|
||||||
m_adjustnbminibatch);
|
m_adjustcoefficient,
|
||||||
m_multiverso->InitModel(learnableNodes);
|
m_adjustnbminibatch);
|
||||||
m_multiversoBarrier = false;
|
m_multiverso->InitModel(learnableNodes);
|
||||||
m_multiverso->WaitAll();
|
m_multiversoBarrier = false;
|
||||||
}
|
m_multiverso->WaitAll();
|
||||||
|
}
|
||||||
|
|
||||||
// --- MAIN EPOCH LOOP
|
// --- MAIN EPOCH LOOP
|
||||||
for (int i = startEpoch; i < (int) m_maxEpochs; i++) // TODO: why is this an int, and not a size_t?
|
for (int i = startEpoch; i < (int) m_maxEpochs; i++) // TODO: why is this an int, and not a size_t?
|
||||||
|
@ -343,9 +344,9 @@ void SGD<ElemType>::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
|
||||||
g_mpi->WaitAll();
|
g_mpi->WaitAll();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (m_parallelizationMethod == ParallelizationMethod::DataParallelASGD && m_nEpochBarrier > 0 && i % m_nEpochBarrier == 0)
|
if (m_parallelizationMethod == ParallelizationMethod::DataParallelASGD && m_nEpochBarrier[i] > 0 && i % m_nEpochBarrier[i] == 0)
|
||||||
{
|
{
|
||||||
m_multiverso->WaitAsyncBuffer(); // [Review:qiwye] does
|
m_multiverso->WaitAsyncBuffer();
|
||||||
m_multiverso->WaitAll();
|
m_multiverso->WaitAll();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -701,11 +702,11 @@ void SGD<ElemType>::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
delete inputMatrices;
|
delete inputMatrices;
|
||||||
if (m_parallelizationMethod == ParallelizationMethod::DataParallelASGD)
|
if (m_parallelizationMethod == ParallelizationMethod::DataParallelASGD)
|
||||||
{
|
{
|
||||||
delete m_multiverso;
|
delete m_multiverso;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----------------------------------------------------------------------
|
// -----------------------------------------------------------------------
|
||||||
|
@ -759,9 +760,9 @@ size_t SGD<ElemType>::TrainOneEpoch(ComputationNetworkPtr net,
|
||||||
(epochNumber >= m_parallelizationStartEpochNum));
|
(epochNumber >= m_parallelizationStartEpochNum));
|
||||||
bool useModelAveraging = ((m_parallelizationMethod == ParallelizationMethod::ModelAveragingSGD) &&
|
bool useModelAveraging = ((m_parallelizationMethod == ParallelizationMethod::ModelAveragingSGD) &&
|
||||||
(epochNumber >= m_parallelizationStartEpochNum));
|
(epochNumber >= m_parallelizationStartEpochNum));
|
||||||
bool useASGD = ((m_parallelizationMethod == ParallelizationMethod::DataParallelASGD) &&
|
bool useASGD = ((m_parallelizationMethod == ParallelizationMethod::DataParallelASGD) &&
|
||||||
(epochNumber >= m_parallelizationStartEpochNum));
|
(epochNumber >= m_parallelizationStartEpochNum));
|
||||||
bool useParallelTrain = useGradientAggregation || useModelAveraging || useASGD;
|
bool useParallelTrain = useGradientAggregation || useModelAveraging || useASGD;
|
||||||
|
|
||||||
// MA-related variables
|
// MA-related variables
|
||||||
size_t nSamplesSinceLastModelSync = 0;
|
size_t nSamplesSinceLastModelSync = 0;
|
||||||
|
@ -872,12 +873,6 @@ size_t SGD<ElemType>::TrainOneEpoch(ComputationNetworkPtr net,
|
||||||
bool wasDataRead = DataReaderHelpers::GetMinibatchIntoNetwork(*trainSetDataReader, net, criterionNodes[0],
|
bool wasDataRead = DataReaderHelpers::GetMinibatchIntoNetwork(*trainSetDataReader, net, criterionNodes[0],
|
||||||
useDistributedMBReading, useParallelTrain, *inputMatrices, actualMBSize);
|
useDistributedMBReading, useParallelTrain, *inputMatrices, actualMBSize);
|
||||||
|
|
||||||
if (!m_multiversoBarrier && useASGD)
|
|
||||||
{
|
|
||||||
m_multiverso->WaitAll();
|
|
||||||
m_multiversoBarrier = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!wasDataRead && (!useDistributedMBReading || noMoreSamplesToProcess)) // in case of distributed reading, we do a few more loops until all ranks have completed
|
if (!wasDataRead && (!useDistributedMBReading || noMoreSamplesToProcess)) // in case of distributed reading, we do a few more loops until all ranks have completed
|
||||||
break;
|
break;
|
||||||
|
|
||||||
|
@ -1103,28 +1098,28 @@ size_t SGD<ElemType>::TrainOneEpoch(ComputationNetworkPtr net,
|
||||||
}
|
}
|
||||||
aggregateNumSamplesWithLabel = processedSamples;
|
aggregateNumSamplesWithLabel = processedSamples;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (useASGD && g_mpi->NumNodesInUse() > 1)
|
if (useASGD && g_mpi->NumNodesInUse() > 1)
|
||||||
|
{
|
||||||
|
// Determine if any samples were processed across any of the ranks
|
||||||
|
if (useDistributedMBReading)
|
||||||
{
|
{
|
||||||
// Determine if any samples were processed across any of the ranks
|
noMoreSamplesToProcess = !wasDataRead;
|
||||||
if (useDistributedMBReading)
|
|
||||||
{
|
|
||||||
noMoreSamplesToProcess = !wasDataRead;
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t processedSamples = 0;
|
|
||||||
if (nSamplesSinceLastModelSync >= m_nFramesBetweenASGDSync)
|
|
||||||
{
|
|
||||||
m_multiverso->PushAndPullModel(learnableNodes);
|
|
||||||
processedSamples = nSamplesSinceLastModelSync;
|
|
||||||
nSamplesSinceLastModelSync = 0;
|
|
||||||
}
|
|
||||||
aggregateNumSamplesWithLabel = processedSamples;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
commTimer.Stop();
|
size_t processedSamples = 0;
|
||||||
commTime += commTimer.ElapsedSeconds();
|
if (nSamplesSinceLastModelSync >= m_nFramesBetweenASGDSync[epochNumber])
|
||||||
|
{
|
||||||
|
m_multiverso->PushAndPullModel(learnableNodes);
|
||||||
|
processedSamples = nSamplesSinceLastModelSync;
|
||||||
|
nSamplesSinceLastModelSync = 0;
|
||||||
|
}
|
||||||
|
aggregateNumSamplesWithLabel = processedSamples;
|
||||||
|
}
|
||||||
|
|
||||||
|
commTimer.Stop();
|
||||||
|
commTime += commTimer.ElapsedSeconds();
|
||||||
|
|
||||||
timer.Stop();
|
timer.Stop();
|
||||||
numMBsRun++;
|
numMBsRun++;
|
||||||
|
@ -1262,15 +1257,15 @@ size_t SGD<ElemType>::TrainOneEpoch(ComputationNetworkPtr net,
|
||||||
nSamplesSinceLastModelSync = 0;
|
nSamplesSinceLastModelSync = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (useASGD && (g_mpi->NumNodesInUse() > 1))
|
if (useASGD && (g_mpi->NumNodesInUse() > 1))
|
||||||
{
|
{
|
||||||
// ASGD also may not be synced after epoch finished, so do the sync here
|
// ASGD also may not be synced after epoch finished, so do the sync here
|
||||||
int residualSampels = (int)nSamplesSinceLastModelSync;
|
int residualSampels = (int)nSamplesSinceLastModelSync;
|
||||||
totalSamplesSeen += residualSampels;
|
totalSamplesSeen += residualSampels;
|
||||||
totalEpochSamples += residualSampels;
|
totalEpochSamples += residualSampels;
|
||||||
m_multiverso->PushAndPullModel(learnableNodes);
|
m_multiverso->PushAndPullModel(learnableNodes);
|
||||||
nSamplesSinceLastModelSync = 0;
|
nSamplesSinceLastModelSync = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
// compute final criterion values
|
// compute final criterion values
|
||||||
if (useGradientAggregation)
|
if (useGradientAggregation)
|
||||||
|
@ -2465,17 +2460,13 @@ static LearningRateSearchAlgorithm ParseLearningRateSearchType(const wstring& s)
|
||||||
else InvalidArgument("autoAdjustLR: Invalid learning rate search type. Valid values are (none | searchBeforeEpoch | adjustAfterEpoch)");
|
else InvalidArgument("autoAdjustLR: Invalid learning rate search type. Valid values are (none | searchBeforeEpoch | adjustAfterEpoch)");
|
||||||
}
|
}
|
||||||
|
|
||||||
static AdjustLearningRateatBeginning AdjustLearningRateAtBeginningType(wstring s)
|
static AdjustLearningRateatBeginning AdjustLearningRateAtBeginningType(wstring s)
|
||||||
{
|
{
|
||||||
if (!_wcsicmp(s.c_str(), L"") || !_wcsicmp(s.c_str(), L"none"))
|
if (EqualCI(s.c_str(), L"") || EqualCI(s.c_str(), L"none")) return AdjustLearningRateatBeginning::None;
|
||||||
return AdjustLearningRateatBeginning::None;
|
else if (EqualCI(s.c_str(), L"linearly")) return AdjustLearningRateatBeginning::Linearly;
|
||||||
else if (!_wcsicmp(s.c_str(), L"linearly"))
|
else if (EqualCI(s.c_str(), L"staircase")) return AdjustLearningRateatBeginning::Staircase;
|
||||||
return AdjustLearningRateatBeginning::Linearly;
|
else InvalidArgument("AdjustLearningRateatBeginningType: Invalid Type. Valid values are (None | Linearly | Staircase)");
|
||||||
else if (!_wcsicmp(s.c_str(), L"staircase"))
|
}
|
||||||
return AdjustLearningRateatBeginning::Staircase;
|
|
||||||
else
|
|
||||||
InvalidArgument("AdjustLearningRateatBeginningType: Invalid Type. Valid values are (None | Linearly | Staircase)");
|
|
||||||
}
|
|
||||||
|
|
||||||
template<class ConfigRecordType>
|
template<class ConfigRecordType>
|
||||||
SGDParams::SGDParams(const ConfigRecordType& configSGD, size_t sizeofElemType)
|
SGDParams::SGDParams(const ConfigRecordType& configSGD, size_t sizeofElemType)
|
||||||
|
@ -2655,37 +2646,37 @@ SGDParams::SGDParams(const ConfigRecordType& configSGD, size_t sizeofElemType)
|
||||||
m_momentumParam = momentumPerSampleVec;
|
m_momentumParam = momentumPerSampleVec;
|
||||||
m_momentumSpecifiedForMBSize = intargvector(L"1");
|
m_momentumSpecifiedForMBSize = intargvector(L"1");
|
||||||
}
|
}
|
||||||
else if (momentumPerMB.size() > 0)
|
else if (momentumPerMB.size() > 0)
|
||||||
{
|
{
|
||||||
m_momentumParam = momentumPerMB;
|
m_momentumParam = momentumPerMB;
|
||||||
m_momentumSpecifiedForMBSize = m_mbSize;
|
m_momentumSpecifiedForMBSize = m_mbSize;
|
||||||
}
|
}
|
||||||
else // default: momentumPerMB = 0.9 per MB
|
else // default: momentumPerMB = 0.9 per MB
|
||||||
{
|
{
|
||||||
m_momentumParam = floatargvector(L"0.9");
|
m_momentumParam = floatargvector(L"0.9");
|
||||||
m_momentumSpecifiedForMBSize = m_mbSize;
|
m_momentumSpecifiedForMBSize = m_mbSize;
|
||||||
}
|
}
|
||||||
m_useNesterovMomentum = useNesterovMomentum;
|
m_useNesterovMomentum = useNesterovMomentum;
|
||||||
|
|
||||||
for (int i = 0; i < m_momentumParam.size(); i++)
|
for (int i = 0; i < m_momentumParam.size(); i++)
|
||||||
{
|
{
|
||||||
if (m_momentumParam[i] >= 1.0 || m_momentumParam[i] < 0.0)
|
if (m_momentumParam[i] >= 1.0 || m_momentumParam[i] < 0.0)
|
||||||
{
|
{
|
||||||
InvalidArgument("Momentum parameter must be in [0, 1).");
|
InvalidArgument("Momentum parameter must be in [0, 1).");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (m_learnRateDecreaseFactor > 1 || m_learnRateIncreaseFactor < 1)
|
if (m_learnRateDecreaseFactor > 1 || m_learnRateIncreaseFactor < 1)
|
||||||
{
|
{
|
||||||
InvalidArgument("learnRateIncreaseFactor must be >= 1 and learnRateDecreaseFactor must be <= 1.");
|
InvalidArgument("learnRateIncreaseFactor must be >= 1 and learnRateDecreaseFactor must be <= 1.");
|
||||||
}
|
}
|
||||||
|
|
||||||
for (size_t i = 0; i < m_dropoutRates.size(); i++)
|
for (size_t i = 0; i < m_dropoutRates.size(); i++)
|
||||||
{
|
{
|
||||||
if (m_dropoutRates[i] >= 1 || m_dropoutRates[i] < 0)
|
if (m_dropoutRates[i] >= 1 || m_dropoutRates[i] < 0)
|
||||||
{
|
{
|
||||||
InvalidArgument("dropoutRate must be >= 0 and < 1.");
|
InvalidArgument("dropoutRate must be >= 0 and < 1.");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (m_adaptationRegWeight > 1 || m_adaptationRegWeight < 0)
|
if (m_adaptationRegWeight > 1 || m_adaptationRegWeight < 0)
|
||||||
|
@ -2707,66 +2698,63 @@ SGDParams::SGDParams(const ConfigRecordType& configSGD, size_t sizeofElemType)
|
||||||
m_parallelizationStartEpochNum = 0;
|
m_parallelizationStartEpochNum = 0;
|
||||||
m_nFramesBetweenMASync = 40000; // default 40k frames
|
m_nFramesBetweenMASync = 40000; // default 40k frames
|
||||||
|
|
||||||
m_nFramesBetweenASGDSync = 1280;
|
|
||||||
m_numMBsToASGDPushAndPull = 0;
|
|
||||||
m_nEpochBarrier = 0;
|
|
||||||
m_adjustlearningrateatbeginning = AdjustLearningRateatBeginning::None;
|
m_adjustlearningrateatbeginning = AdjustLearningRateatBeginning::None;
|
||||||
|
|
||||||
|
|
||||||
if ((g_mpi != nullptr) && configSGD.Exists(L"ParallelTrain"))
|
if ((g_mpi != nullptr) && configSGD.Exists(L"ParallelTrain"))
|
||||||
{
|
{
|
||||||
const ConfigRecordType& configParallelTrain(configSGD(L"ParallelTrain", ConfigRecordType::Record()));
|
const ConfigRecordType& configParallelTrain(configSGD(L"ParallelTrain", ConfigRecordType::Record()));
|
||||||
m_parallelizationMethod = ParseParallelizationMethod(configParallelTrain(L"parallelizationMethod", L"none"));
|
m_parallelizationMethod = ParseParallelizationMethod(configParallelTrain(L"parallelizationMethod", L"none"));
|
||||||
m_parallelizationStartEpochNum = configParallelTrain(L"parallelizationStartEpoch", (int) 1) - 1; // Epoch numbers internally are 0 based
|
m_parallelizationStartEpochNum = configParallelTrain(L"parallelizationStartEpoch", (int)1) - 1; // Epoch numbers internally are 0 based
|
||||||
m_enableDistributedMBReading = configParallelTrain(L"distributedMBReading", false);
|
m_enableDistributedMBReading = configParallelTrain(L"distributedMBReading", false);
|
||||||
m_syncStatsTrace = configParallelTrain(L"syncPerfStats", (int) 0);
|
m_syncStatsTrace = configParallelTrain(L"syncPerfStats", (int)0);
|
||||||
|
|
||||||
if (configParallelTrain.Exists(L"DataParallelSGD"))
|
if (configParallelTrain.Exists(L"DataParallelSGD"))
|
||||||
{
|
{
|
||||||
const ConfigRecordType& configDataParallelSGD(configParallelTrain(L"DataParallelSGD", ConfigRecordType::Record()));
|
const ConfigRecordType& configDataParallelSGD(configParallelTrain(L"DataParallelSGD", ConfigRecordType::Record()));
|
||||||
size_t defaultGradientBits = 8 * sizeofElemType;
|
size_t defaultGradientBits = 8 * sizeofElemType;
|
||||||
m_numGradientBits = configDataParallelSGD(L"gradientBits", defaultGradientBits);
|
m_numGradientBits = configDataParallelSGD(L"gradientBits", defaultGradientBits);
|
||||||
m_zeroThresholdFor1Bit = configDataParallelSGD(L"useZeroThresholdFor1BitQuantization", true);
|
m_zeroThresholdFor1Bit = configDataParallelSGD(L"useZeroThresholdFor1BitQuantization", true);
|
||||||
m_bufferedAsyncGradientAggregation = configDataParallelSGD(L"useBufferedAsyncGradientAggregation", false);
|
m_bufferedAsyncGradientAggregation = configDataParallelSGD(L"useBufferedAsyncGradientAggregation", false);
|
||||||
if ((m_numGradientBits < 1) || (m_numGradientBits > (8 * sizeofElemType)))
|
if ((m_numGradientBits < 1) || (m_numGradientBits >(8 * sizeofElemType)))
|
||||||
{
|
|
||||||
InvalidArgument("gradientBits must be in the range [1, 32] when using precision=float and in range [1, 64] when using precision=double!");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (configParallelTrain.Exists(L"ModelAveragingSGD"))
|
|
||||||
{
|
|
||||||
const ConfigRecordType& configMASGD(configParallelTrain(L"ModelAveragingSGD", ConfigRecordType::Record()));
|
|
||||||
m_nFramesBetweenMASync = configMASGD(L"syncFrequencyInFrames", (size_t) 40000);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (configParallelTrain.Exists(L"DataParallelASGD"))
|
|
||||||
{
|
{
|
||||||
const ConfigRecordType & configDataParallelASGD(configParallelTrain(L"DataParallelASGD", ConfigRecordType::Record()));
|
InvalidArgument("gradientBits must be in the range [1, 32] when using precision=float and in range [1, 64] when using precision=double!");
|
||||||
m_nFramesBetweenASGDSync = configDataParallelASGD(L"SyncFrequencyInFrames", (size_t)1280);
|
|
||||||
m_isPipeline = configDataParallelASGD(L"UsePipeline", true);
|
|
||||||
m_nEpochBarrier = configDataParallelASGD(L"EpochBarrier", (size_t)0);
|
|
||||||
if (configDataParallelASGD.Exists(L"AdjustLearningRateAtBeginning"))
|
|
||||||
{
|
|
||||||
const ConfigRecordType & configAdjustLearningRateAtBeginning(configDataParallelASGD(L"AdjustLearningRateAtBeginning", ConfigRecordType::Record()));
|
|
||||||
m_adjustlearningrateatbeginning = AdjustLearningRateAtBeginningType(configAdjustLearningRateAtBeginning(L"adjustType", L"None"));
|
|
||||||
m_adjustcoefficient = configAdjustLearningRateAtBeginning(L"adjustCoefficient", (double)0.2);
|
|
||||||
m_adjustnbminibatch = configAdjustLearningRateAtBeginning(L"adjustNbMinibatch", (size_t)600);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (configParallelTrain.Exists(L"ModelAveragingSGD"))
|
||||||
|
{
|
||||||
|
const ConfigRecordType& configMASGD(configParallelTrain(L"ModelAveragingSGD", ConfigRecordType::Record()));
|
||||||
|
m_nFramesBetweenMASync = configMASGD(L"syncFrequencyInFrames", (size_t)40000);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (configParallelTrain.Exists(L"DataParallelASGD"))
|
||||||
|
{
|
||||||
|
const ConfigRecordType & configDataParallelASGD(configParallelTrain(L"DataParallelASGD", ConfigRecordType::Record()));
|
||||||
|
m_nFramesBetweenASGDSync = configDataParallelASGD(L"SyncFrequencyInFrames", ConfigRecordType::Array(intargvector(vector<int>{1280})));
|
||||||
|
m_isPipeline = configDataParallelASGD(L"UsePipeline", true);
|
||||||
|
m_nEpochBarrier = configDataParallelASGD(L"EpochBarrier", ConfigRecordType::Array(intargvector(vector<int>{0})));
|
||||||
|
if (configDataParallelASGD.Exists(L"AdjustLearningRateAtBeginning"))
|
||||||
|
{
|
||||||
|
const ConfigRecordType & configAdjustLearningRateAtBeginning(configDataParallelASGD(L"AdjustLearningRateAtBeginning", ConfigRecordType::Record()));
|
||||||
|
m_adjustlearningrateatbeginning = AdjustLearningRateAtBeginningType(configAdjustLearningRateAtBeginning(L"adjustType", L"None"));
|
||||||
|
m_adjustcoefficient = configAdjustLearningRateAtBeginning(L"adjustCoefficient", (double)0.2);
|
||||||
|
m_adjustnbminibatch = configAdjustLearningRateAtBeginning(L"adjustNbMinibatch", (size_t)600);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static size_t GetSizeOfPrecision(const ScriptableObjects::IConfigRecordPtr configp)
|
static size_t GetSizeOfPrecision(const ScriptableObjects::IConfigRecordPtr configp)
|
||||||
{
|
{
|
||||||
wstring precision = configp->Get(L"precision");
|
wstring precision = configp->Get(L"precision");
|
||||||
if (precision == L"float")
|
if (precision == L"float")
|
||||||
return sizeof(float);
|
return sizeof(float);
|
||||||
else if (precision == L"double")
|
else if (precision == L"double")
|
||||||
return sizeof(double);
|
return sizeof(double);
|
||||||
else
|
else
|
||||||
RuntimeError("invalid value '%ls' for 'precision', must be 'float' or 'double'", precision.c_str());
|
RuntimeError("invalid value '%ls' for 'precision', must be 'float' or 'double'", precision.c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
SGDParams::SGDParams(const ScriptableObjects::IConfigRecordPtr configp)
|
SGDParams::SGDParams(const ScriptableObjects::IConfigRecordPtr configp)
|
||||||
|
|
|
@ -249,11 +249,11 @@ protected:
|
||||||
double m_L1RegWeight;
|
double m_L1RegWeight;
|
||||||
|
|
||||||
// Parallel training related with ASGD
|
// Parallel training related with ASGD
|
||||||
size_t m_numMBsToASGDPushAndPull; // decide how many minibatchs should ASGD to a pull&push to parameter server.
|
intargvector m_numMBsToASGDPushAndPull; // decide how many minibatchs should ASGD to a pull&push to parameter server.
|
||||||
// note that, this will override m_nFramesBetweenASGDSync when set.
|
// note that, this will override m_nFramesBetweenASGDSync when set.
|
||||||
size_t m_nFramesBetweenASGDSync;
|
intargvector m_nFramesBetweenASGDSync;
|
||||||
bool m_isPipeline;
|
bool m_isPipeline;
|
||||||
size_t m_nEpochBarrier;
|
intargvector m_nEpochBarrier;
|
||||||
AdjustLearningRateatBeginning m_adjustlearningrateatbeginning;
|
AdjustLearningRateatBeginning m_adjustlearningrateatbeginning;
|
||||||
double m_adjustcoefficient;
|
double m_adjustcoefficient;
|
||||||
size_t m_adjustnbminibatch;
|
size_t m_adjustnbminibatch;
|
||||||
|
|
Загрузка…
Ссылка в новой задаче