sparse update support ready
This commit is contained in:
Родитель
8fc2d2135d
Коммит
51dfd58cad
|
@ -4,6 +4,7 @@
|
|||
// the header files in ..\Multiverso\include
|
||||
#include <multiverso/multiverso.h>
|
||||
#include <multiverso/table/matrix_table.h>
|
||||
#include <multiverso/util/configure.h>
|
||||
#pragma comment(lib, "IMultiverso.lib")
|
||||
|
||||
#ifndef CPUONLY
|
||||
|
@ -225,8 +226,8 @@ namespace Microsoft {
|
|||
// waiting copy from GPU to CPU has finished
|
||||
CudaErrorCheck(cudaStreamSynchronize(_commStream));
|
||||
|
||||
// calculate delta
|
||||
std::transform(m_deltaArray, m_deltaArray + m_totalModelSize, m_cpuAsyncBuffer[t_cacheIdx], m_deltaArray, std::minus<ElemType>());
|
||||
// delta = gradient * learning_rate
|
||||
std::transform(m_cpuAsyncBuffer[t_cacheIdx], m_cpuAsyncBuffer[t_cacheIdx] + m_totalModelSize, m_deltaArray, m_deltaArray, std::minus<ElemType>());
|
||||
|
||||
// lr decay
|
||||
std::transform(m_deltaArray, m_deltaArray + m_totalModelSize, m_deltaArray, std::bind1st(std::multiplies<ElemType>(), factor));
|
||||
|
@ -239,9 +240,6 @@ namespace Microsoft {
|
|||
multiversoMatrix->Get(py, m_tableLength[widx]);
|
||||
}
|
||||
|
||||
//m_matrixArray->Add(m_deltaArray, m_totalModelSize);
|
||||
//m_matrixArray->Get(m_cpuAsyncBuffer[t_cacheIdx], m_totalModelSize);
|
||||
|
||||
// copy parameters from CPU buffer to GPU buffer
|
||||
for (int widx = 0; widx < m_tableCount; widx++)
|
||||
{
|
||||
|
@ -262,7 +260,7 @@ namespace Microsoft {
|
|||
float factor = DecayCoefficient();
|
||||
int t_cacheIdx = m_bufferInUse;
|
||||
|
||||
std::transform(m_deltaArray, m_deltaArray + m_totalModelSize, m_cpuAsyncBuffer[t_cacheIdx], m_deltaArray, std::minus<ElemType>());
|
||||
std::transform(m_cpuAsyncBuffer[t_cacheIdx], m_cpuAsyncBuffer[t_cacheIdx] + m_totalModelSize, m_deltaArray, m_deltaArray, std::minus<ElemType>());
|
||||
std::transform(m_deltaArray, m_deltaArray + m_totalModelSize, m_deltaArray, std::bind1st(std::multiplies<ElemType>(), factor));
|
||||
for (int widx = 0; widx < m_tableCount; widx++)
|
||||
{
|
||||
|
@ -291,7 +289,7 @@ namespace Microsoft {
|
|||
mat.CopyToArray(px, m_tableLength[i]);
|
||||
}
|
||||
|
||||
std::transform(m_deltaArray, m_deltaArray + m_totalModelSize, m_cpuAsyncBuffer[0], m_deltaArray, std::minus<ElemType>());
|
||||
std::transform(m_cpuAsyncBuffer[0], m_cpuAsyncBuffer[0] + m_totalModelSize, m_deltaArray, m_deltaArray, std::minus<ElemType>());
|
||||
|
||||
// lr decay
|
||||
std::transform(m_deltaArray, m_deltaArray + m_totalModelSize, m_deltaArray, std::bind1st(std::multiplies<ElemType>(), factor));
|
||||
|
@ -303,8 +301,6 @@ namespace Microsoft {
|
|||
multiversoMatrix->Add(px, m_tableLength[widx]);
|
||||
multiversoMatrix->Get(py, m_tableLength[widx]);
|
||||
}
|
||||
//m_matrixArray->Add(m_deltaArray, m_totalModelSize);
|
||||
//m_matrixArray->Get(m_cpuAsyncBuffer[0], m_totalModelSize);
|
||||
|
||||
i = 0;
|
||||
for (auto nodeIter = learnableNodes.begin(); nodeIter != learnableNodes.end(); nodeIter++, i++)
|
||||
|
@ -349,6 +345,7 @@ namespace Microsoft {
|
|||
m_isInitialized = true;
|
||||
|
||||
multiverso::MV_Init();
|
||||
multiverso::SetCMDFlag<std::string>(std::string("updater_type"), std::string("sgd"));
|
||||
|
||||
m_matrixArray = new std::vector< multiverso::MatrixWorkerTable<ElemType>*>();
|
||||
m_serverArray = new std::vector< multiverso::MatrixServerTable<ElemType>*>();
|
||||
|
|
|
@ -1 +1 @@
|
|||
Subproject commit 6f6cf3d44407cfdb15a4152b7d73faef7eafea29
|
||||
Subproject commit f9217b4909609425db9bb5a5889b20357fc14fce
|
Загрузка…
Ссылка в новой задаче