This commit is contained in:
Qiwei Ye 2016-06-15 20:12:42 +08:00
Родитель d8b1a25d71
Коммит 88883aebfa
2 изменённых файлов: 23 добавлений и 4 удалений

Просмотреть файл

@ -26,8 +26,8 @@
#include <numeric>
#include <algorithm>
#define MULTIVERSO_DEBUG
namespace Microsoft { namespace MSR { namespace CNTK {
#define MULTIVERSO_DEBUG
#ifndef CPUONLY
#define CudaErrorCheck(ans) { gpuAssert((ans), __FILE__, __LINE__); }
@ -343,6 +343,7 @@ bool PushAndPullModel(const std::list<ComputationNodeBasePtr> & learnableNodes,
}
else
{
timer.Restart();
float factor = DecayCoefficient();
i = 0;
for (auto nodeIter = learnableNodes.begin(); nodeIter != learnableNodes.end(); nodeIter++, i++)
@ -354,6 +355,12 @@ bool PushAndPullModel(const std::list<ComputationNodeBasePtr> & learnableNodes,
mat.CopyToArray(px, m_tableLength[i]);
}
timer.Stop();
if (m_traceLevel > 3)
{
double time = timer.ElapsedSeconds();
fprintf(stderr, "\t\t -- pullAndRequest, GPU -> CPU time %lf \n", time);
}
std::transform(m_cpuAsyncBuffer[0], m_cpuAsyncBuffer[0] + m_totalModelSize, m_deltaArray, m_deltaArray, std::minus<ElemType>());
// lr decay
@ -374,7 +381,7 @@ bool PushAndPullModel(const std::list<ComputationNodeBasePtr> & learnableNodes,
{
std::transform(m_deltaArray, m_deltaArray + m_totalModelSize, m_deltaArray, std::bind1st(std::multiplies<ElemType>(), factor));
}
timer.Restart();
for (int widx = 0; widx < m_tableCount; widx++)
{
if (m_isSparseArray[widx])
@ -394,7 +401,13 @@ bool PushAndPullModel(const std::list<ComputationNodeBasePtr> & learnableNodes,
multiversoMatrix->Get(py, m_tableLength[widx]);
}
}
timer.Stop();
if (m_traceLevel > 3)
{
double time = timer.ElapsedSeconds();
fprintf(stderr, "\t\t -- pullAndRequest, Worker <--> Multiverso time %lf \n", time);
}
timer.Restart();
i = 0;
for (auto nodeIter = learnableNodes.begin(); nodeIter != learnableNodes.end(); nodeIter++, i++)
{
@ -404,6 +417,12 @@ bool PushAndPullModel(const std::list<ComputationNodeBasePtr> & learnableNodes,
ElemType * px = m_cpuAsyncBuffer[0] + m_tableOffsets[i];
mat.SetValue(mat.GetNumRows(), mat.GetNumCols(), mat.GetDeviceId(), px);
}
timer.Stop();
if (m_traceLevel > 3)
{
double time = timer.ElapsedSeconds();
fprintf(stderr, "\t\t -- pullAndRequest, CPU -> GPU time %lf \n", time);
}
}
return true;
}

@ -1 +1 @@
Subproject commit 31803209699e4454ad2f1b705ae84fa064fb508c
Subproject commit de4fb5c033f6ead87940a2e7a3c2e872b9f38692