This commit is contained in:
Qiwei Ye 2016-04-13 16:32:26 +08:00
Родитель 8fc2d2135d
Коммит 51dfd58cad
2 изменённых файлов: 7 добавлений и 10 удалений

Просмотреть файл

@ -4,6 +4,7 @@
// the header files in ..\Multiverso\include
#include <multiverso/multiverso.h>
#include <multiverso/table/matrix_table.h>
#include <multiverso/util/configure.h>
#pragma comment(lib, "IMultiverso.lib")
#ifndef CPUONLY
@ -225,8 +226,8 @@ namespace Microsoft {
// waiting copy from GPU to CPU has finished
CudaErrorCheck(cudaStreamSynchronize(_commStream));
// calculate delta
std::transform(m_deltaArray, m_deltaArray + m_totalModelSize, m_cpuAsyncBuffer[t_cacheIdx], m_deltaArray, std::minus<ElemType>());
// delta = gradient * learning_rate
std::transform(m_cpuAsyncBuffer[t_cacheIdx], m_cpuAsyncBuffer[t_cacheIdx] + m_totalModelSize, m_deltaArray, m_deltaArray, std::minus<ElemType>());
// lr decay
std::transform(m_deltaArray, m_deltaArray + m_totalModelSize, m_deltaArray, std::bind1st(std::multiplies<ElemType>(), factor));
@ -239,9 +240,6 @@ namespace Microsoft {
multiversoMatrix->Get(py, m_tableLength[widx]);
}
//m_matrixArray->Add(m_deltaArray, m_totalModelSize);
//m_matrixArray->Get(m_cpuAsyncBuffer[t_cacheIdx], m_totalModelSize);
// copy parameters from CPU buffer to GPU buffer
for (int widx = 0; widx < m_tableCount; widx++)
{
@ -262,7 +260,7 @@ namespace Microsoft {
float factor = DecayCoefficient();
int t_cacheIdx = m_bufferInUse;
std::transform(m_deltaArray, m_deltaArray + m_totalModelSize, m_cpuAsyncBuffer[t_cacheIdx], m_deltaArray, std::minus<ElemType>());
std::transform(m_cpuAsyncBuffer[t_cacheIdx], m_cpuAsyncBuffer[t_cacheIdx] + m_totalModelSize, m_deltaArray, m_deltaArray, std::minus<ElemType>());
std::transform(m_deltaArray, m_deltaArray + m_totalModelSize, m_deltaArray, std::bind1st(std::multiplies<ElemType>(), factor));
for (int widx = 0; widx < m_tableCount; widx++)
{
@ -291,7 +289,7 @@ namespace Microsoft {
mat.CopyToArray(px, m_tableLength[i]);
}
std::transform(m_deltaArray, m_deltaArray + m_totalModelSize, m_cpuAsyncBuffer[0], m_deltaArray, std::minus<ElemType>());
std::transform(m_cpuAsyncBuffer[0], m_cpuAsyncBuffer[0] + m_totalModelSize, m_deltaArray, m_deltaArray, std::minus<ElemType>());
// lr decay
std::transform(m_deltaArray, m_deltaArray + m_totalModelSize, m_deltaArray, std::bind1st(std::multiplies<ElemType>(), factor));
@ -303,8 +301,6 @@ namespace Microsoft {
multiversoMatrix->Add(px, m_tableLength[widx]);
multiversoMatrix->Get(py, m_tableLength[widx]);
}
//m_matrixArray->Add(m_deltaArray, m_totalModelSize);
//m_matrixArray->Get(m_cpuAsyncBuffer[0], m_totalModelSize);
i = 0;
for (auto nodeIter = learnableNodes.begin(); nodeIter != learnableNodes.end(); nodeIter++, i++)
@ -349,6 +345,7 @@ namespace Microsoft {
m_isInitialized = true;
multiverso::MV_Init();
multiverso::SetCMDFlag<std::string>(std::string("updater_type"), std::string("sgd"));
m_matrixArray = new std::vector< multiverso::MatrixWorkerTable<ElemType>*>();
m_serverArray = new std::vector< multiverso::MatrixServerTable<ElemType>*>();

@ -1 +1 @@
Subproject commit 6f6cf3d44407cfdb15a4152b7d73faef7eafea29
Subproject commit f9217b4909609425db9bb5a5889b20357fc14fce