Initial implementation of RMSProp for dense matrices. Does not work for sparse matrices.
Jasha Droppo 2014-09-14 19:20:31 -07:00
Parent 87657c403d
Commit 2dc6540f21
10 changed files: 261 additions and 73 deletions

View file

@@ -44,6 +44,23 @@ namespace Microsoft { namespace MSR { namespace CNTK {
         RmsProp
     };

+    // configuration parameters associated with RMSProp learning algorithm
+    typedef struct stRMSPropInfo{
+        double gamma;
+        double inc;
+        double dec;
+        double max;
+        double min;
+        stRMSPropInfo()
+        {
+            gamma = 0.99;
+            inc = 1.2;
+            dec = 0.75;
+            max = 10.0;
+            min = 0.1;
+        }
+    }RMSPropInfo;
+
     typedef struct stGradientUpdateInfo{
         GradientsUpdateType mType;
         float mGaussianNoiseInjectStd;
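For orientation, here is a minimal scalar sketch (not the committed code) of how the five RMSPropInfo fields enter the per-weight update; avar, step, and prevSign stand in for the per-weight state this commit stores alongside the smoothed gradient:

    #include <algorithm>
    #include <cmath>

    // Hypothetical helper: rescale one weight's gradient g in the style above.
    double RmsPropScale(double g, double& avar, double& step, double& prevSign,
                        const RMSPropInfo& p)
    {
        avar = p.gamma * avar + (1.0 - p.gamma) * g * g;  // smoothed squared gradient
        double sign = (g > 0) - (g < 0);
        if (sign * prevSign > 0)                          // same direction as last step:
            step = std::min(step * p.inc, p.max);         //   grow the step, clip at max
        else                                              // direction flipped (or g == 0):
            step = std::max(step * p.dec, p.min);         //   shrink the step, clip at min
        prevSign = sign;
        return step * g / std::sqrt(avar + 1e-6);         // RMS-normalized, rescaled gradient
    }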
@@ -123,6 +140,14 @@ namespace Microsoft { namespace MSR { namespace CNTK {
     gUpdateInfo.mType = gradUpdateType;
     gUpdateInfo.mGaussianNoiseInjectStd = (float)gaussianNoiseInjecStd;

+    // extract RMSProp parameters from config, if they exist. Default to reasonable values.
+    RMSPropInfo rpi;
+    rpi.dec = (double) configSGD("rms_wgt_dec","0.75");
+    rpi.inc = (double) configSGD("rms_wgt_inc","1.2");
+    rpi.min = (double) configSGD("rms_wgt_min","0.1");
+    rpi.max = (double) configSGD("rms_wgt_max","10.0");
+    rpi.gamma = (double) configSGD("rms_gamma","0.99");
+
     /// for backward support. future setup should use gradUpdateType=AdaGrad, instead of
     /// useAdagrad=true
     bool useAdagrad = configSGD("useAdagrad", "false");
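The five keys read here can then be set in the SGD section of a training config. A hypothetical snippet (the key names and defaults are the ones read above; the surrounding block syntax is illustrative):

    SGD = [
        gradUpdateType = RmsProp
        rms_gamma   = 0.99
        rms_wgt_inc = 1.2
        rms_wgt_dec = 0.75
        rms_wgt_max = 10.0
        rms_wgt_min = 0.1
    ]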
@@ -146,7 +171,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {
     reduceLearnRateIfImproveLessThan, continueReduce, learnRateDecreaseFactor, dropoutRates,
     loadBestModel, numMiniBatch4LRSearch, numPrevLearnRates, numBestSearchEpoch, (UINT16)traceLevel, numMBsToShowResult,
     maxTempMemSizeInSamplesForCNN, gUpdateInfo, usePtask, keepCheckPointFiles, adaptationRegType, adaptationRegWeight,
-    trainCriterionNodeName, evalCriterionNodeName, doGradientCheck, gradientCheckSigDigit, validateAfterModelReloading);
+    trainCriterionNodeName, evalCriterionNodeName, doGradientCheck, gradientCheckSigDigit, validateAfterModelReloading,
+    rpi);
 }

 void setMomentum(float momentum)
@@ -167,7 +193,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {
     const size_t numMBsToShowResult = 10, const size_t maxTempMemSizeInSamplesForCNN = 0,
     const GradientUpdateInfo gradUpdateType = GradientUpdateInfo(), const bool usePtask = false, const bool keepCheckPointFiles=false, const AdaptationRegType adaptationRegType = AdaptationRegType::None,
     const ElemType adaptationRegWeight = 0.0f, const wstring trainCriterionNodeName= L"", const wstring evalCriterionNodeName=L"",
-    const bool doGradientCheck = false, const ElemType gradientCheckSigDigit = 6, const bool validateAfterModelReloading = true)
+    const bool doGradientCheck = false, const ElemType gradientCheckSigDigit = 6, const bool validateAfterModelReloading = true,
+    RMSPropInfo rpi = RMSPropInfo())
 {
     numPrevLearnRates;
     m_mbSize=mbSize;
@@ -195,6 +222,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
     m_numBestSearchEpoch=numBestSearchEpoch;
     m_maxTempMemSizeInSamplesForCNN=maxTempMemSizeInSamplesForCNN;
     m_gradType = gradUpdateType;
+    m_rpi = rpi;
     m_usePtask = usePtask;
     m_keepCheckPointFiles = keepCheckPointFiles;
@@ -1096,7 +1124,9 @@ public:
     }
     if (adpType == GradientsUpdateType::RmsProp)
     {
-        smoothedGradient.RmsProp(gradientValues);
+        // include L2 regularizer
+        Matrix<ElemType>::ScaleAndAdd(0.001,functionValues,gradientValues);
+        smoothedGradient.RmsProp(gradientValues,sgd->m_rpi.gamma,sgd->m_rpi.inc,sgd->m_rpi.max,sgd->m_rpi.dec,sgd->m_rpi.min);
         Matrix<ElemType>::ScaleAndAdd(-learnRatePerSample, gradientValues, functionValues);
     }
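Spelled out, the two ScaleAndAdd calls and RmsProp together apply the following update, with W = functionValues, g = gradientValues, v the accumulated variances, s the per-weight step sizes, γ = rms_gamma, η = learnRatePerSample, and ε the small floor (the sign-driven update of s is folded into RmsProp; see the kernels below):

\[
g \leftarrow g + 0.001\,W, \qquad
v \leftarrow \gamma\,v + (1-\gamma)\,g \odot g, \qquad
g \leftarrow s \odot g \oslash \sqrt{v + \varepsilon}, \qquad
W \leftarrow W - \eta\,g
\]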
@@ -1423,6 +1453,8 @@ protected:
     ElemType m_minLearnRate;

     GradientUpdateInfo m_gradType;
+    RMSPropInfo m_rpi;
+
     bool m_usePtask;
     bool m_keepCheckPointFiles;

View file

@@ -865,43 +865,89 @@ namespace Microsoft { namespace MSR { namespace CNTK {
 }

 template<class ElemType>
-void CPUMatrix<ElemType>::RmsProp(CPUMatrix<ElemType>& gradients)
+void CPUMatrix<ElemType>::RmsProp(CPUMatrix<ElemType>& gradients,
+    ElemType RMS_GAMMA,
+    ElemType RMS_WGT_INC,
+    ElemType RMS_WGT_MAX,
+    ElemType RMS_WGT_DEC,
+    ElemType RMS_WGT_MIN
+    )
 {
-    if (this->IsEmpty())
+    const ElemType floor = 1e-6f;
+
+    size_t n = gradients.GetNumElements();
+    ElemType *curr_grad=gradients.m_pArray;
+
+    if (this->IsEmpty() || this->GetNumCols() < gradients.GetNumCols() * 3)
     {
-        this->Resize(gradients.GetNumRows(), gradients.GetNumCols());
+        this->Resize(gradients.GetNumRows(), gradients.GetNumCols() * 3);
         this->SetValue(0.0);
+
+        ElemType *avars=m_pArray;         // accumulated variances for RMS scaling
+        ElemType *steps=m_pArray+2*n;     // current step size
+
+        // initialize moving average of gradient-squared
+        for( long i = 0; i < n; i++ )
+            avars[i] = curr_grad[i]*curr_grad[i];
+
+        // initialize starting step size
+        for( long i = 0; i < n; i++ )
+            steps[i] = ElemType(0.02);
     }
-    assert(this->GetNumRows() == gradients.GetNumRows() && this->GetNumCols() == gradients.GetNumCols());
+
+    ElemType *avars=m_pArray;         // accumulated variances for RMS scaling
+    ElemType *signs=m_pArray+n;       // sign of previous gradient
+    ElemType *steps=m_pArray+2*n;     // current step size
-    ElemType *a=m_pArray, *d_v=gradients.m_pArray;
-    size_t n = GetNumElements();
-    long nLoop = (long)n - n%4;
+    assert(this->GetNumRows() == gradients.GetNumRows() && this->GetNumCols() == gradients.GetNumCols() * 3);
-    const ElemType floor = 1e-16f;
+
+    ElemType ONE_MINUS_GAMMA = ElemType(1.0) - RMS_GAMMA;
+    //int upd[] = {
+    //    2,2,0,
+    //    2,2,0,
+    //    1,1,1,
+    //    2,2,0,
+    //    1,2,1,
+    //    0,2,2,
+    //    1,1,1,
+    //    0,2,2,
+    //    0,2,2,
+    //};

 #pragma omp parallel for
-    for (long i=0; i<nLoop; i+=4)
+    // for (long i=0; i<n; i++)
+    // {
+    //     avars[i] = RMS_GAMMA * avars[i] + ONE_MINUS_GAMMA * (curr_grad[i] * curr_grad[i]);
+    //     // grad sign base 3: 0->neg, 1->zero, 2->pos
+    //     const int grad_sign = 1 + (ElemType(0) < curr_grad[i]) - (curr_grad[i] < ElemType(0));
+    //     // signs[i] contains three consecutive grad_sign
+    //     signs[i] = 3*(int(signs[i]) % 9) + grad_sign;
+    //     switch(upd[int(signs[i])])
+    //     {
+    //     case 0:
+    //         steps[i] = max(steps[i] * RMS_WGT_DEC, RMS_WGT_MIN);
+    //         break;
+    //     case 2:
+    //         steps[i] = min(steps[i] * RMS_WGT_INC, RMS_WGT_MAX);
+    //         break;
+    //     }
+    //     curr_grad[i] *= steps[i] / sqrt(avars[i] + floor);
+    // }
+    for (long i=0; i<n; i++)
     {
-        a[i]   = a[i]   * ElemType(0.9) + ElemType(0.1) * d_v[i]   * d_v[i];
-        a[i+1] = a[i+1] * ElemType(0.9) + ElemType(0.1) * d_v[i+1] * d_v[i+1];
-        a[i+2] = a[i+2] * ElemType(0.9) + ElemType(0.1) * d_v[i+2] * d_v[i+2];
-        a[i+3] = a[i+3] * ElemType(0.9) + ElemType(0.1) * d_v[i+3] * d_v[i+3];
+        avars[i] = RMS_GAMMA * avars[i] + ONE_MINUS_GAMMA * (curr_grad[i] * curr_grad[i]);
+        const int grad_sign = (ElemType(0) < curr_grad[i]) - (curr_grad[i] < ElemType(0));
-        d_v[i]   /= (sqrt(a[i])   + floor);
-        d_v[i+1] /= (sqrt(a[i+1]) + floor);
-        d_v[i+2] /= (sqrt(a[i+2]) + floor);
-        d_v[i+3] /= (sqrt(a[i+3]) + floor);
+        if( signs[i] * grad_sign > 0 )
+            steps[i] = min(steps[i] * RMS_WGT_INC, RMS_WGT_MAX);
+        else
+            steps[i] = max(steps[i] * RMS_WGT_DEC, RMS_WGT_MIN);
+        curr_grad[i] *= steps[i] / sqrt(avars[i] + floor);
+        signs[i] = grad_sign;
     }
-    for (long i=nLoop; i<n; i++)
-    {
-        a[i] = a[i] * ElemType(0.9) + ElemType(0.1) * d_v[i] * d_v[i];
-        d_v[i] /= (sqrt(a[i]) + floor);
-    }
 }

 template<class ElemType>
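A note on the state layout: the smoothed-gradient matrix is widened to three times the gradient's column count and addressed as one flat buffer with n = rows·cols elements per slab. A hypothetical stand-alone illustration of the same packing:

    #include <cstddef>
    #include <vector>

    // Three parallel per-weight arrays packed into one allocation,
    // mirroring the avars/signs/steps slabs above.
    struct RmsPropState
    {
        std::vector<float> buf;                          // 3*n elements, zero-initialized
        explicit RmsPropState(std::size_t n) : buf(3 * n) {}
        std::size_t n() const { return buf.size() / 3; }
        float* avars() { return buf.data(); }            // running mean of g^2
        float* signs() { return buf.data() + n(); }      // sign of previous gradient
        float* steps() { return buf.data() + 2 * n(); }  // per-weight step sizes
    };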

View file

@@ -58,7 +58,13 @@ namespace Microsoft { namespace MSR { namespace CNTK {
     CPUMatrix<ElemType>& AssignColumnSlice(const CPUMatrix<ElemType>& fromMatrix, size_t startColumn, size_t numCols);

     void Adagrad(CPUMatrix<ElemType>& gradients);
-    void RmsProp(CPUMatrix<ElemType>& gradients);
+    void RmsProp(CPUMatrix<ElemType>& gradients,
+        ElemType RMS_GAMMA,
+        ElemType RMS_WGT_INC,
+        ElemType RMS_WGT_MAX,
+        ElemType RMS_WGT_DEC,
+        ElemType RMS_WGT_MIN
+        );

     void Reshape(const size_t numRows, const size_t numCols);
     void Resize(const size_t numRows, const size_t numCols, bool growOnly = true); //by default we only reallocate if need to grow

View file

@@ -760,42 +760,6 @@ namespace Microsoft { namespace MSR { namespace CNTK {
     }
 }

-template<class ElemType>
-void CPUSparseMatrix<ElemType>::RmsProp(CPUMatrix<ElemType>& c)
-{
-    if (c.IsEmpty())
-    {
-        c.Resize(this->GetNumRows(), this->GetNumCols());
-        c.SetValue(0.0);
-    }
-
-    if(c.GetFormat() == MatrixFormat::matrixFormatSparseCSC)
-    {
-        const ElemType floor = 1e-16f;
-        for(size_t j = 0; j < GetNumCols(); j++)
-        {
-            size_t start = m_pb[j];
-            size_t end = m_pb[j+1];
-            for(size_t p = start; p < end; p++)
-            {
-                size_t i = m_row[p];
-                ElemType val = m_val[p];
-
-                ElemType adenorm = c(i, j);
-                adenorm = adenorm * (ElemType)0.9 + (ElemType)0.1 * val * val;
-                val = val / (floor + sqrt(adenorm));
-                m_val[p] = val;
-                c(i, j) = adenorm;
-            }
-        }
-    }
-    else
-    {
-        throw std::exception("CPUSparseMatrix:: RmsProp() only support CSC");
-    }
-}
-
 template<class ElemType>
 CPUSparseMatrix<ElemType>& CPUSparseMatrix<ElemType>::InplaceTruncate (const ElemType threshold)
 {

View file

@@ -85,7 +85,6 @@ namespace Microsoft { namespace MSR { namespace CNTK {
 public:
     void NormalGrad(CPUMatrix<ElemType>& c, const ElemType momentum);
     void Adagrad(CPUMatrix<ElemType>& c);
-    void RmsProp(CPUMatrix<ElemType>& c);

 public:
     CPUSparseMatrix<ElemType>& InplaceTruncateTop (const ElemType /*threshold*/) { NOT_IMPLEMENTED; }

View file

@@ -966,6 +966,63 @@ namespace Microsoft { namespace MSR { namespace CNTK {
     _adagrad<ElemType><<<blocksPerGrid, threadsPerBlock>>>(m_pArray, gradients.m_pArray, GetNumElements());
 }

+template<class ElemType>
+void GPUMatrix<ElemType>::RmsProp(GPUMatrix<ElemType>& gradients,
+    ElemType RMS_GAMMA,
+    ElemType RMS_WGT_INC,
+    ElemType RMS_WGT_MAX,
+    ElemType RMS_WGT_DEC,
+    ElemType RMS_WGT_MIN
+    )
+{
+    const ElemType floor = 1e-6f;
+
+    static ElemType *upd_gpu = (ElemType*)0;
+
+    size_t n = gradients.GetNumElements();
+    int blocksPerGrid = (GetNumElements() + threadsPerBlock -1 )/threadsPerBlock;
+
+    if (this->IsEmpty() || this->GetNumCols() < gradients.GetNumCols() * 3)
+    {
+        this->Resize(gradients.GetNumRows(), gradients.GetNumCols() * 3);
+        this->SetValue(0.0);
+
+        ElemType *avars=m_pArray;         // accumulated variances for RMS scaling
+        ElemType *signs=m_pArray+n;       // sign of previous gradient
+        ElemType *steps=m_pArray+2*n;     // current step size
+
+        _rmsprop_init<ElemType><<<blocksPerGrid, threadsPerBlock>>>(avars,signs,steps,gradients.m_pArray,n);
+    }
+
+    ElemType *avars=m_pArray;         // accumulated variances for RMS scaling
+    ElemType *signs=m_pArray+n;       // sign of previous gradient
+    ElemType *steps=m_pArray+2*n;     // current step size
+
+    assert(this->GetNumRows() == gradients.GetNumRows() && this->GetNumCols() == gradients.GetNumCols() * 3);
+
+    if( !upd_gpu )
+    {
+        ElemType upd[] = {
+            2,2,0,
+            2,2,0,
+            1,1,1,
+            2,2,0,
+            1,2,1,
+            0,2,2,
+            1,1,1,
+            0,2,2,
+            0,2,2,
+        };
+        CUDA_CALL(cudaMalloc((void**)&upd_gpu,sizeof(ElemType)*27));
+        CUDA_CALL(cudaMemcpy(upd_gpu,upd,sizeof(ElemType)*27,cudaMemcpyHostToDevice));
+    }
+
+    _rmsprop<ElemType><<<blocksPerGrid, threadsPerBlock>>>(avars,signs,steps,gradients.m_pArray,n,
+        RMS_GAMMA,RMS_WGT_INC,RMS_WGT_MAX,RMS_WGT_DEC,RMS_WGT_MIN,
+        floor,upd_gpu);
+}
+
 template<class ElemType>
 void GPUMatrix<ElemType>::Reshape(const size_t numRows, const size_t numCols)
 {
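The 27-entry upd table is indexed by a base-3 history of the last three gradient signs. The kernel as committed only consults the most recent sign, but the table supports the commented-out three-step variant. A sketch of the encoding (hypothetical helper names):

    // 0 -> negative, 1 -> zero, 2 -> positive
    int gradSign3(float g)             { return 1 + (0.0f < g) - (g < 0.0f); }
    // keep the last three signs as a base-3 number in [0, 27)
    int pushSign(int history, float g) { return 3 * (history % 9) + gradSign3(g); }
    // upd[history] then selects the action: 0 -> shrink step, 1 -> keep, 2 -> grow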

View file

@@ -99,7 +99,13 @@ namespace Microsoft { namespace MSR { namespace CNTK {
     ElemType* BufferPointer() const {return m_pArray;}

     void Adagrad(GPUMatrix<ElemType>& gradients);
+    void RmsProp(GPUMatrix<ElemType>& gradients,
+        ElemType RMS_GAMMA,
+        ElemType RMS_WGT_INC,
+        ElemType RMS_WGT_MAX,
+        ElemType RMS_WGT_DEC,
+        ElemType RMS_WGT_MIN
+        );
     void Reshape(const size_t numRows, const size_t numCols);
     void Resize(const size_t numRows, const size_t numCols, bool growOnly = true); //by default we only reallocate if need to grow

View file

@@ -972,6 +972,72 @@ __global__ void _adagrad(
     d_v[id] /= sqrt(a[id]+floor);
 }

+template<class ElemType>
+__global__ void _rmsprop_init(
+    ElemType* avars, ElemType* signs, ElemType* steps,
+    ElemType* curr_grad,
+    const LONG64 N
+    )
+{
+    LONG64 i = blockDim.x * blockIdx.x + threadIdx.x;
+    if (i >= N)
+        return;
+
+    ElemType tmp = curr_grad[i];
+    avars[i] = tmp * tmp;
+    signs[i] = ElemType(0.0);
+    steps[i] = ElemType(0.02);
+}
+
+template<class ElemType>
+__global__ void _rmsprop(
+    ElemType* avars, ElemType* signs, ElemType* steps,
+    ElemType* curr_grad,
+    const LONG64 N,
+    ElemType RMS_GAMMA,ElemType RMS_WGT_INC,ElemType RMS_WGT_MAX,ElemType RMS_WGT_DEC,ElemType RMS_WGT_MIN,
+    ElemType floor,
+    ElemType *upd_gpu
+    )
+{
+    LONG64 i = blockDim.x * blockIdx.x + threadIdx.x;
+    if (i >= N)
+        return;
+
+    avars[i] = RMS_GAMMA * avars[i] + (ElemType(1.0)-RMS_GAMMA)* (curr_grad[i] * curr_grad[i]);
+
+    //// grad sign base 3: 0->neg, 1->zero, 2->pos
+    //const int grad_sign = 1 + (ElemType(0) < curr_grad[i]) - (curr_grad[i] < ElemType(0));
+    //// signs[i] contains three consecutive grad_sign
+    //signs[i] = 3*(int(signs[i]) % 9) + grad_sign;
+    //// update according to the following table:
+    //// (!pos,!pos,!pos) or (!neg,!neg,!neg): RMS_WGT_INC
+    //// (!neg,!neg,neg) or (!pos,!pos,pos): RMS_WGT_DEC
+    //// otherwise: no action
+    //switch(int(upd_gpu[int(signs[i])]))
+    //{
+    //case 0:
+    //    steps[i] = max(steps[i] * RMS_WGT_DEC, RMS_WGT_MIN);
+    //    break;
+    //case 2:
+    //    steps[i] = min(steps[i] * RMS_WGT_INC, RMS_WGT_MAX);
+    //    break;
+    //}
+    //curr_grad[i] *= steps[i] / sqrt(avars[i] + floor);
+
+    const int grad_sign = (ElemType(0) < curr_grad[i]) - (curr_grad[i] < ElemType(0));
+    if( signs[i] * grad_sign > 0 )
+        steps[i] = min(steps[i] * RMS_WGT_INC, RMS_WGT_MAX);
+    else
+        steps[i] = max(steps[i] * RMS_WGT_DEC, RMS_WGT_MIN);
+    curr_grad[i] *= steps[i] / sqrt(avars[i] + floor);
+    signs[i] = grad_sign;
+}
+
 template<class ElemType>
 __global__ void _rescaleToRange(
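The grad_sign expression in _rmsprop is the standard branchless signum, which avoids divergent branches on the GPU; in isolation:

    #include <cassert>

    // Branchless signum: returns -1, 0, or +1.
    template <typename T>
    int signum(T x) { return (T(0) < x) - (x < T(0)); }

    int main()
    {
        assert(signum(3.5f) == 1);
        assert(signum(-2.0) == -1);
        assert(signum(0) == 0);
    }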

View file

@@ -1098,15 +1098,21 @@ namespace Microsoft { namespace MSR { namespace CNTK {
 }

 template<class ElemType>
-void Matrix<ElemType>::RmsProp(Matrix<ElemType>& gradients)
+void Matrix<ElemType>::RmsProp(Matrix<ElemType>& gradients,
+    ElemType RMS_GAMMA,
+    ElemType RMS_WGT_INC,
+    ElemType RMS_WGT_MAX,
+    ElemType RMS_WGT_DEC,
+    ElemType RMS_WGT_MIN
+    )
 {
     DecideAndMoveToRightDevice(*this, gradients);

     DISPATCH_MATRIX_ON_FLAG(this,
         &gradients,
-        m_CPUMatrix->RmsProp(*gradients.m_CPUMatrix); SetDataLocation(CPU),
+        m_CPUMatrix->RmsProp(*gradients.m_CPUMatrix, RMS_GAMMA, RMS_WGT_INC, RMS_WGT_MAX, RMS_WGT_DEC, RMS_WGT_MIN); SetDataLocation(CPU),
+        m_GPUMatrix->RmsProp(*gradients.m_GPUMatrix, RMS_GAMMA, RMS_WGT_INC, RMS_WGT_MAX, RMS_WGT_DEC, RMS_WGT_MIN); SetDataLocation(GPU),
         NOT_IMPLEMENTED,
-        m_CPUSparseMatrix->RmsProp(*this->m_CPUMatrix); SetDataLocation(CPU),
         NOT_IMPLEMENTED
         );
 }
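A hypothetical call site mirroring how SGD invokes the new overload (note the argument order: gamma, inc, max, dec, min):

    // gradientValues is rescaled in place; smoothedGradient accumulates
    // the 3-slab avars/signs/steps state across calls.
    void ApplyRmsProp(Matrix<float>& smoothedGradient, Matrix<float>& gradientValues,
                      Matrix<float>& functionValues, float learnRatePerSample,
                      const RMSPropInfo& rpi)
    {
        smoothedGradient.RmsProp(gradientValues,
                                 (float)rpi.gamma, (float)rpi.inc, (float)rpi.max,
                                 (float)rpi.dec,   (float)rpi.min);
        Matrix<float>::ScaleAndAdd(-learnRatePerSample, gradientValues, functionValues);
    }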

View file

@@ -107,7 +107,13 @@ namespace Microsoft { namespace MSR { namespace CNTK {
     void NormalGrad(Matrix<ElemType>& gradients, Matrix<ElemType>& functionValues, const ElemType learnRatePerSample, const ElemType momentum);
     void Adagrad(Matrix<ElemType>& gradients);
-    void RmsProp(Matrix<ElemType>& gradients);
+    void RmsProp(Matrix<ElemType>& gradients,
+        ElemType RMS_GAMMA,
+        ElemType RMS_WGT_INC,
+        ElemType RMS_WGT_MAX,
+        ElemType RMS_WGT_DEC,
+        ElemType RMS_WGT_MIN
+        );
     void Reshape(const size_t numRows, const size_t numCols);
     void Resize(const size_t numRows, const size_t numCols, bool growOnly = true); //by default we only reallocate if need to grow