From 2dc6540f217881ff86b8a923a0dc689bbb2d76df Mon Sep 17 00:00:00 2001
From: Jasha Droppo
Date: Sun, 14 Sep 2014 19:20:31 -0700
Subject: [PATCH] Initial implementation of RMSProp for dense matrices. Does not work for sparse matrices.

---
 MachineLearning/cn/SGD.h          | 40 +++++++++++--
 Math/Math/CPUMatrix.cpp           | 98 +++++++++++++++++++++++--------
 Math/Math/CPUMatrix.h             |  8 ++-
 Math/Math/CPUSparseMatrix.cpp     | 36 ------------
 Math/Math/CPUSparseMatrix.h       |  1 -
 Math/Math/GPUMatrix.cu            | 57 ++++++++++++++++++
 Math/Math/GPUMatrix.cuh           |  8 ++-
 Math/Math/GPUMatrixCUDAKernels.cu | 66 +++++++++++++++++++++
 Math/Math/Matrix.cpp              | 12 +++-
 Math/Math/Matrix.h                |  8 ++-
 10 files changed, 261 insertions(+), 73 deletions(-)

diff --git a/MachineLearning/cn/SGD.h b/MachineLearning/cn/SGD.h
index eaccc13fc..15490b8f0 100644
--- a/MachineLearning/cn/SGD.h
+++ b/MachineLearning/cn/SGD.h
@@ -44,7 +44,24 @@ namespace Microsoft { namespace MSR { namespace CNTK {
         RmsProp
     };
 
-    typedef struct stGradientUpdateInfo{
+    // configuration parameters associated with RMSProp learning algorithm
+    typedef struct stRMSPropInfo{
+        double gamma;
+        double inc;
+        double dec;
+        double max;
+        double min;
+        stRMSPropInfo()
+        {
+            gamma = 0.99;
+            inc = 1.2;
+            dec = 0.75;
+            max = 10.0;
+            min = 0.1;
+        }
+    }RMSPropInfo;
+
+    typedef struct stGradientUpdateInfo{
         GradientsUpdateType mType;
         float mGaussianNoiseInjectStd;
         stGradientUpdateInfo()
@@ -123,6 +140,14 @@ namespace Microsoft { namespace MSR { namespace CNTK {
             gUpdateInfo.mType = gradUpdateType;
             gUpdateInfo.mGaussianNoiseInjectStd = (float)gaussianNoiseInjecStd;
 
+            // extract RMSProp parameters from config, if they exist. Default to reasonable values.
+            RMSPropInfo rpi;
+            rpi.dec = (double) configSGD("rms_wgt_dec","0.75");
+            rpi.inc = (double) configSGD("rms_wgt_inc","1.2");
+            rpi.min = (double) configSGD("rms_wgt_min","0.1");
+            rpi.max = (double) configSGD("rms_wgt_max","10.0");
+            rpi.gamma = (double) configSGD("rms_gamma","0.99");
+
             /// for backward support. future setup should use gradUpdateType=AdaGrad, instead of
             /// useAdagrad=true
             bool useAdagrad = configSGD("useAdagrad", "false");
@@ -146,7 +171,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {
                 reduceLearnRateIfImproveLessThan, continueReduce, learnRateDecreaseFactor, dropoutRates, loadBestModel,
                 numMiniBatch4LRSearch, numPrevLearnRates, numBestSearchEpoch, (UINT16)traceLevel, numMBsToShowResult,
                 maxTempMemSizeInSamplesForCNN, gUpdateInfo, usePtask, keepCheckPointFiles, adaptationRegType, adaptationRegWeight,
-                trainCriterionNodeName, evalCriterionNodeName, doGradientCheck, gradientCheckSigDigit, validateAfterModelReloading);
+                trainCriterionNodeName, evalCriterionNodeName, doGradientCheck, gradientCheckSigDigit, validateAfterModelReloading,
+                rpi);
         }
 
         void setMomentum(float momentum)
@@ -167,7 +193,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {
             const size_t numMBsToShowResult = 10, const size_t maxTempMemSizeInSamplesForCNN = 0,
             const GradientUpdateInfo gradUpdateType = GradientUpdateInfo(), const bool usePtask = false, const bool keepCheckPointFiles=false, const AdaptationRegType adaptationRegType = AdaptationRegType::None,
             const ElemType adaptationRegWeight = 0.0f, const wstring trainCriterionNodeName= L"", const wstring evalCriterionNodeName=L"",
-            const bool doGradientCheck = false, const ElemType gradientCheckSigDigit = 6, const bool validateAfterModelReloading = true)
+            const bool doGradientCheck = false, const ElemType gradientCheckSigDigit = 6, const bool validateAfterModelReloading = true,
+            RMSPropInfo rpi = RMSPropInfo())
         {
             numPrevLearnRates;
             m_mbSize=mbSize;
@@ -195,6 +222,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
             m_numBestSearchEpoch=numBestSearchEpoch;
             m_maxTempMemSizeInSamplesForCNN=maxTempMemSizeInSamplesForCNN;
             m_gradType = gradUpdateType;
+            m_rpi = rpi;
             m_usePtask = usePtask;
             m_keepCheckPointFiles = keepCheckPointFiles;
 
@@ -1096,7 +1124,9 @@ public:
         }
         if (adpType == GradientsUpdateType::RmsProp)
         {
-            smoothedGradient.RmsProp(gradientValues);
+            // include L2 regularizer
+            Matrix<ElemType>::ScaleAndAdd(0.001,functionValues,gradientValues);
+            smoothedGradient.RmsProp(gradientValues,sgd->m_rpi.gamma,sgd->m_rpi.inc,sgd->m_rpi.max,sgd->m_rpi.dec,sgd->m_rpi.min);
             Matrix<ElemType>::ScaleAndAdd(-learnRatePerSample, gradientValues, functionValues);
         }
 
@@ -1423,6 +1453,8 @@ protected:
        ElemType m_minLearnRate;
 
        GradientUpdateInfo m_gradType;
+       RMSPropInfo m_rpi;
+
        bool m_usePtask;
        bool m_keepCheckPointFiles;
 
diff --git a/Math/Math/CPUMatrix.cpp b/Math/Math/CPUMatrix.cpp
index 7a5e97b47..c85b19d56 100644
--- a/Math/Math/CPUMatrix.cpp
+++ b/Math/Math/CPUMatrix.cpp
@@ -865,43 +865,89 @@ namespace Microsoft { namespace MSR { namespace CNTK {
     }
 
     template<class ElemType>
-    void CPUMatrix<ElemType>::RmsProp(CPUMatrix<ElemType>& gradients)
+    void CPUMatrix<ElemType>::RmsProp(CPUMatrix<ElemType>& gradients,
+        ElemType RMS_GAMMA,
+        ElemType RMS_WGT_INC,
+        ElemType RMS_WGT_MAX,
+        ElemType RMS_WGT_DEC,
+        ElemType RMS_WGT_MIN
+        )
     {
-        if (this->IsEmpty())
+        const ElemType floor = 1e-6f;
+
+        size_t n = gradients.GetNumElements();
+        ElemType *curr_grad=gradients.m_pArray;
+
+        if (this->IsEmpty() || this->GetNumCols() < gradients.GetNumCols() * 3)
         {
-            this->Resize(gradients.GetNumRows(), gradients.GetNumCols());
+            this->Resize(gradients.GetNumRows(), gradients.GetNumCols() * 3);
             this->SetValue(0.0);
+
+            ElemType *avars=m_pArray; // accumulated variances for RMS scaling
+            ElemType *steps=m_pArray+2*n; // current step size
+
+            // initialize moving average of gradient-squared
+            for( long i = 0; i < n; i++ )
+                avars[i] = curr_grad[i]*curr_grad[i];
+
+            // initialize starting step size
+            for( long i = 0; i < n; i++ )
+                steps[i] = ElemType(0.02);
         }
 
-        assert(this->GetNumRows() == gradients.GetNumRows() && this->GetNumCols() == gradients.GetNumCols());
+        ElemType *avars=m_pArray; // accumulated variances for RMS scaling
+        ElemType *signs=m_pArray+n; // sign of previous gradient
+        ElemType *steps=m_pArray+2*n; // current step size
 
-        ElemType *a=m_pArray, *d_v=gradients.m_pArray;
-        size_t n = GetNumElements();
-        long nLoop = (long)n - n%4;
+        assert(this->GetNumRows() == gradients.GetNumRows() && this->GetNumCols() == gradients.GetNumCols() * 3);
 
-        const ElemType floor = 1e-16f;
+        ElemType ONE_MINUS_GAMMA = ElemType(1.0) - RMS_GAMMA;
 
+        //int upd[] = {
+        //    2,2,0,
+        //    2,2,0,
+        //    1,1,1,
+        //    2,2,0,
+        //    1,2,1,
+        //    0,2,2,
+        //    1,1,1,
+        //    0,2,2,
+        //    0,2,2,
+        //};
 
-#pragma omp parallel for
-        for (long i=0; i<nLoop; i+=4)
+        // for (long i=0; i<n; i++)
+        // {
+        //     avars[i] = RMS_GAMMA * avars[i] + ONE_MINUS_GAMMA * (curr_grad[i] * curr_grad[i]);
+        //     // grad sign base 3: 0->neg, 1->zero, 2->pos
+        //     const int grad_sign = 1 + (ElemType(0) < curr_grad[i]) - (curr_grad[i] < ElemType(0));
+
+        //     // signs[i] contains three consecutive grad_sign
+        //     signs[i] = 3*(int(signs[i]) % 9) + grad_sign;
+
+        //     switch(upd[int(signs[i])])
+        //     {
+        //     case 0:
+        //         steps[i] = max(steps[i] * RMS_WGT_DEC, RMS_WGT_MIN);
+        //         break;
+        //     case 2:
+        //         steps[i] = min(steps[i] * RMS_WGT_INC, RMS_WGT_MAX);
+        //         break;
+        //     }
+        //     curr_grad[i] *= steps[i] / sqrt(avars[i] + floor);
+        // }
+
+        for (long i=0; i<n; i++)
+        {
+            avars[i] = RMS_GAMMA * avars[i] + ONE_MINUS_GAMMA * (curr_grad[i] * curr_grad[i]);
+
+            const int grad_sign = (ElemType(0) < curr_grad[i]) - (curr_grad[i] < ElemType(0));
+            if( signs[i] * grad_sign > 0 )
+                steps[i] = min(steps[i] * RMS_WGT_INC, RMS_WGT_MAX);
+            else
+                steps[i] = max(steps[i] * RMS_WGT_DEC, RMS_WGT_MIN);
+
+            curr_grad[i] *= steps[i] / sqrt(avars[i] + floor);
+            signs[i] = grad_sign;
         }
-
-        for (long i=nLoop; i<n; i++)
     }
 
     template<class ElemType>
diff --git a/Math/Math/CPUMatrix.h b/Math/Math/CPUMatrix.h
index c03603d35..44178b81f 100644
--- a/Math/Math/CPUMatrix.h
+++ b/Math/Math/CPUMatrix.h
@@ -58,7 +58,13 @@ namespace Microsoft { namespace MSR { namespace CNTK {
         CPUMatrix<ElemType>& AssignColumnSlice(const CPUMatrix<ElemType>& fromMatrix, size_t startColumn, size_t numCols);
 
         void Adagrad(CPUMatrix<ElemType>& gradients);
-        void RmsProp(CPUMatrix<ElemType>& gradients);
+        void RmsProp(CPUMatrix<ElemType>& gradients,
+            ElemType RMS_GAMMA,
+            ElemType RMS_WGT_INC,
+            ElemType RMS_WGT_MAX,
+            ElemType RMS_WGT_DEC,
+            ElemType RMS_WGT_MIN
+            );
 
         void Reshape(const size_t numRows, const size_t numCols);
         void Resize(const size_t numRows, const size_t numCols, bool growOnly = true);  //by default we only reallocate if need to grow
diff --git a/Math/Math/CPUSparseMatrix.cpp b/Math/Math/CPUSparseMatrix.cpp
index 4ae8081e9..fc421c138 100644
--- a/Math/Math/CPUSparseMatrix.cpp
+++ b/Math/Math/CPUSparseMatrix.cpp
@@ -760,42 +760,6 @@ namespace Microsoft { namespace MSR { namespace CNTK {
         }
     }
 
-
-    template<class ElemType>
-    void CPUSparseMatrix<ElemType>::RmsProp(CPUMatrix<ElemType>& c)
-    {
-        if (c.IsEmpty())
-        {
-            c.Resize(this->GetNumRows(), this->GetNumCols());
-            c.SetValue(0.0);
-        }
-
-        if(c.GetFormat() == MatrixFormat::matrixFormatSparseCSC)
-        {
-            const ElemType floor = 1e-16f;
-            for(size_t j = 0; j < GetNumCols(); j++)
-            {
-                size_t start = m_pb[j];
-                size_t end = m_pb[j+1];
-                for(size_t p = start; p < end; p++)
-                {
-                    size_t i = m_row[p];
-                    ElemType val = m_val[p];
-
-                    ElemType adenorm = c(i, j);
-                    adenorm = adenorm * (ElemType)0.9 + (ElemType)0.1 * val * val;
-                    val = val / (floor + sqrt(adenorm));
-                    m_val[p] = val;
-                    c(i, j) = adenorm;
-                }
-            }
-        }
-        else
-        {
-            throw std::exception("CPUSparseMatrix:: RmsProp() only support CSC");
-        }
-    }
-
     template<class ElemType>
     CPUSparseMatrix<ElemType>& CPUSparseMatrix<ElemType>::InplaceTruncate (const ElemType threshold)
     {
diff --git a/Math/Math/CPUSparseMatrix.h b/Math/Math/CPUSparseMatrix.h
index 57ebb8af2..a73fb05b9 100644
--- a/Math/Math/CPUSparseMatrix.h
+++ b/Math/Math/CPUSparseMatrix.h
@@ -85,7 +85,6 @@ namespace Microsoft { namespace MSR { namespace CNTK {
     public:
         void NormalGrad(CPUMatrix<ElemType>& c, const ElemType momentum);
         void Adagrad(CPUMatrix<ElemType>& c);
-        void RmsProp(CPUMatrix<ElemType>& c);
 
     public:
         CPUSparseMatrix<ElemType>& InplaceTruncateTop (const ElemType /*threshold*/) { NOT_IMPLEMENTED; }
diff --git a/Math/Math/GPUMatrix.cu b/Math/Math/GPUMatrix.cu
index 701290eaf..1088c8493 100644
--- a/Math/Math/GPUMatrix.cu
+++ b/Math/Math/GPUMatrix.cu
@@ -966,6 +966,63 @@ namespace Microsoft { namespace MSR { namespace CNTK {
         _adagrad<<<blocksPerGrid, threadsPerBlock>>>(m_pArray, gradients.m_pArray, GetNumElements());
     }
 
+    template<class ElemType>
+    void GPUMatrix<ElemType>::RmsProp(GPUMatrix<ElemType>& gradients,
+        ElemType RMS_GAMMA,
+        ElemType RMS_WGT_INC,
+        ElemType RMS_WGT_MAX,
+        ElemType RMS_WGT_DEC,
+        ElemType RMS_WGT_MIN
+        )
+    {
+        const ElemType floor = 1e-6f;
+        static ElemType *upd_gpu = (ElemType*)0;
+
+        size_t n = gradients.GetNumElements();
+        int blocksPerGrid = (GetNumElements() + threadsPerBlock -1 )/threadsPerBlock;
+
+        if (this->IsEmpty() || this->GetNumCols() < gradients.GetNumCols() * 3)
+        {
+            this->Resize(gradients.GetNumRows(), gradients.GetNumCols() * 3);
+            this->SetValue(0.0);
+
+            ElemType *avars=m_pArray; // accumulated variances for RMS scaling
+            ElemType *signs=m_pArray+n; // sign of previous gradient
+            ElemType *steps=m_pArray+2*n; // current step size
+
+            _rmsprop_init<<<blocksPerGrid, threadsPerBlock>>>(avars,signs,steps,gradients.m_pArray,n);
+
+        }
+
+        ElemType *avars=m_pArray; // accumulated variances for RMS scaling
+        ElemType *signs=m_pArray+n; // sign of previous gradient
+        ElemType *steps=m_pArray+2*n; // current step size
+
+        assert(this->GetNumRows() == gradients.GetNumRows() && this->GetNumCols() == gradients.GetNumCols() * 3);
+
+        if( !upd_gpu )
+        {
+            ElemType upd[] = {
+                2,2,0,
+                2,2,0,
+                1,1,1,
+                2,2,0,
+                1,2,1,
+                0,2,2,
+                1,1,1,
+                0,2,2,
+                0,2,2,
+            };
+
+            CUDA_CALL(cudaMalloc((void**)&upd_gpu,sizeof(ElemType)*27));
+            CUDA_CALL(cudaMemcpy(upd_gpu,upd,sizeof(ElemType)*27,cudaMemcpyHostToDevice));
+        }
+
+        _rmsprop<<<blocksPerGrid, threadsPerBlock>>>(avars,signs,steps,gradients.m_pArray,n,
+            RMS_GAMMA,RMS_WGT_INC,RMS_WGT_MAX,RMS_WGT_DEC,RMS_WGT_MIN,
+            floor,upd_gpu);
+    }
+
     template<class ElemType>
     void GPUMatrix<ElemType>::Reshape(const size_t numRows, const size_t numCols)
     {
diff --git a/Math/Math/GPUMatrix.cuh b/Math/Math/GPUMatrix.cuh
index 6687d0db9..1ae0f773e 100644
--- a/Math/Math/GPUMatrix.cuh
+++ b/Math/Math/GPUMatrix.cuh
@@ -99,7 +99,13 @@ namespace Microsoft { namespace MSR { namespace CNTK {
         ElemType* BufferPointer() const {return m_pArray;}
 
         void Adagrad(GPUMatrix<ElemType>& gradients);
-
+        void RmsProp(GPUMatrix<ElemType>& gradients,
+            ElemType RMS_GAMMA,
+            ElemType RMS_WGT_INC,
+            ElemType RMS_WGT_MAX,
+            ElemType RMS_WGT_DEC,
+            ElemType RMS_WGT_MIN
+            );
         void Reshape(const size_t numRows, const size_t numCols);
         void Resize(const size_t numRows, const size_t numCols, bool growOnly = true);  //by default we only reallocate if need to grow
diff --git a/Math/Math/GPUMatrixCUDAKernels.cu b/Math/Math/GPUMatrixCUDAKernels.cu
index 98b760af4..8e104d42e 100644
--- a/Math/Math/GPUMatrixCUDAKernels.cu
+++ b/Math/Math/GPUMatrixCUDAKernels.cu
@@ -972,6 +972,72 @@ __global__ void _adagrad(
     d_v[id] /= sqrt(a[id]+floor);
 }
 
+template<class ElemType>
+__global__ void _rmsprop_init(
+    ElemType* avars, ElemType* signs, ElemType* steps,
+    ElemType* curr_grad,
+    const LONG64 N
+    )
+{
+    LONG64 i = blockDim.x * blockIdx.x + threadIdx.x;
+    if (i >= N)
+        return;
+
+    ElemType tmp = curr_grad[i];
+    avars[i] = tmp * tmp;
+    signs[i] = ElemType(0.0);
+    steps[i] = ElemType(0.02);
+}
+
+template<class ElemType>
+__global__ void _rmsprop(
+    ElemType* avars, ElemType* signs, ElemType* steps,
+    ElemType* curr_grad,
+    const LONG64 N,
+    ElemType RMS_GAMMA,ElemType RMS_WGT_INC,ElemType RMS_WGT_MAX,ElemType RMS_WGT_DEC,ElemType RMS_WGT_MIN,
+    ElemType floor,
+    ElemType *upd_gpu
+    )
+{
+    LONG64 i = blockDim.x * blockIdx.x + threadIdx.x;
+    if (i >= N)
+        return;
+
+    avars[i] = RMS_GAMMA * avars[i] + (ElemType(1.0)-RMS_GAMMA)* (curr_grad[i] * curr_grad[i]);
+
+    //// grad sign base 3: 0->neg, 1->zero, 2->pos
+    //const int grad_sign = 1 + (ElemType(0) < curr_grad[i]) - (curr_grad[i] < ElemType(0));
+
+    //// signs[i] contains three consecutive grad_sign
+    //signs[i] = 3*(int(signs[i]) % 9) + grad_sign;
+
+    //// update according to the following table:
+    //// (!pos,!pos,!pos) or (!neg,!neg,!neg): RMS_WGT_INC
+    //// (!neg,!neg,neg) or (!pos,!pos,pos): RMS_WGT_DEC
+    //// otherwise: no action
+
+    //switch(int(upd_gpu[int(signs[i])]))
+    //{
+    //case 0:
+    //    steps[i] = max(steps[i] * RMS_WGT_DEC, RMS_WGT_MIN);
+    //    break;
+    //case 2:
+    //    steps[i] = min(steps[i] * RMS_WGT_INC, RMS_WGT_MAX);
+    //    break;
+    //}
+    //curr_grad[i] *= steps[i] / sqrt(avars[i] + floor);
+
+    const int grad_sign = (ElemType(0) < curr_grad[i]) - (curr_grad[i] < ElemType(0));
+
+    if( signs[i] * grad_sign > 0 )
+        steps[i] = min(steps[i] * RMS_WGT_INC, RMS_WGT_MAX);
+    else
+        steps[i] = max(steps[i] * RMS_WGT_DEC, RMS_WGT_MIN);
+
+    curr_grad[i] *= steps[i] / sqrt(avars[i] + floor);
+    signs[i] = grad_sign;
+
+}
 
 template<class ElemType>
 __global__ void _rescaleToRange(
diff --git a/Math/Math/Matrix.cpp b/Math/Math/Matrix.cpp
index 93d6bbfcc..b5faf7ff5 100644
--- a/Math/Math/Matrix.cpp
+++ b/Math/Math/Matrix.cpp
@@ -1098,15 +1098,21 @@ namespace Microsoft { namespace MSR { namespace CNTK {
     }
 
     template<class ElemType>
-    void Matrix<ElemType>::RmsProp(Matrix<ElemType>& gradients)
+    void Matrix<ElemType>::RmsProp(Matrix<ElemType>& gradients,
+        ElemType RMS_GAMMA,
+        ElemType RMS_WGT_INC,
+        ElemType RMS_WGT_MAX,
+        ElemType RMS_WGT_DEC,
+        ElemType RMS_WGT_MIN
+        )
     {
         DecideAndMoveToRightDevice(*this, gradients);
 
         DISPATCH_MATRIX_ON_FLAG(this,
            &gradients,
-            m_CPUMatrix->RmsProp(*gradients.m_CPUMatrix); SetDataLocation(CPU),
+            m_CPUMatrix->RmsProp(*gradients.m_CPUMatrix, RMS_GAMMA, RMS_WGT_INC, RMS_WGT_MAX, RMS_WGT_DEC, RMS_WGT_MIN); SetDataLocation(CPU),
+            m_GPUMatrix->RmsProp(*gradients.m_GPUMatrix, RMS_GAMMA, RMS_WGT_INC, RMS_WGT_MAX, RMS_WGT_DEC, RMS_WGT_MIN); SetDataLocation(GPU),
             NOT_IMPLEMENTED,
-            m_CPUSparseMatrix->RmsProp(*this->m_CPUMatrix); SetDataLocation(CPU),
             NOT_IMPLEMENTED
             );
     }
diff --git a/Math/Math/Matrix.h b/Math/Math/Matrix.h
index b88b033a9..d901b3853 100644
--- a/Math/Math/Matrix.h
+++ b/Math/Math/Matrix.h
@@ -107,7 +107,13 @@ namespace Microsoft { namespace MSR { namespace CNTK {
         void NormalGrad(Matrix<ElemType>& gradients, Matrix<ElemType>& functionValues, const ElemType learnRatePerSample, const ElemType momentum);
         void Adagrad(Matrix<ElemType>& gradients);
-        void RmsProp(Matrix<ElemType>& gradients);
+        void RmsProp(Matrix<ElemType>& gradients,
+            ElemType RMS_GAMMA,
+            ElemType RMS_WGT_INC,
+            ElemType RMS_WGT_MAX,
+            ElemType RMS_WGT_DEC,
+            ElemType RMS_WGT_MIN
+            );
 
         void Reshape(const size_t numRows, const size_t numCols);
         void Resize(const size_t numRows, const size_t numCols, bool growOnly = true);  //by default we only reallocate if need to grow
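
For reference, the CPU loop in CPUMatrix.cpp and the _rmsprop kernel above apply the same per-element update to the gradient. The sketch below is illustrative only and is not part of the patch: the free function, its name, and the plain-array interface are assumptions made for the example. In the patch the three working arrays are the column blocks of the smoothed-gradient matrix, which is why that matrix is resized to three times the gradient's column count, and the five tuning values are the ones read from the SGD config section as rms_gamma, rms_wgt_inc, rms_wgt_max, rms_wgt_dec and rms_wgt_min (defaults 0.99, 1.2, 10.0, 0.75, 0.1).

#include <algorithm>
#include <cmath>

// Illustrative only: per-element RMSProp update mirroring the patch's CPU/GPU code.
// grad is rescaled in place; the caller then applies it with the learning rate.
template <class ElemType>
void RmsPropUpdate(ElemType* grad, ElemType* avars, ElemType* signs, ElemType* steps, long n,
                   ElemType gamma, ElemType inc, ElemType maxStep, ElemType dec, ElemType minStep)
{
    const ElemType floor = ElemType(1e-6);
    for (long i = 0; i < n; i++)
    {
        // moving average of the squared gradient
        avars[i] = gamma * avars[i] + (ElemType(1) - gamma) * grad[i] * grad[i];

        // grow the per-weight step while the gradient keeps its sign, otherwise shrink it
        const int grad_sign = (ElemType(0) < grad[i]) - (grad[i] < ElemType(0));
        if (signs[i] * grad_sign > 0)
            steps[i] = std::min(steps[i] * inc, maxStep);
        else
            steps[i] = std::max(steps[i] * dec, minStep);

        // scale the gradient by the adaptive step over the RMS of recent gradients
        grad[i] *= steps[i] / std::sqrt(avars[i] + floor);
        signs[i] = ElemType(grad_sign);
    }
}

Note that the caller in SGD.h first mixes a small L2 term into the gradient (Matrix<ElemType>::ScaleAndAdd(0.001, functionValues, gradientValues)) before RmsProp is invoked, and that the sparse path was removed, so only dense gradients are supported by this change.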