diff --git a/Source/Math/GPUMatrix.cu b/Source/Math/GPUMatrix.cu index f2c5c73ca..0e7ada4d5 100755 --- a/Source/Math/GPUMatrix.cu +++ b/Source/Math/GPUMatrix.cu @@ -1537,10 +1537,10 @@ void GPUMatrix::FSAdagrad(GPUMatrix& gradients, template void GPUMatrix::Adam(GPUMatrix& gradients, GPUMatrix& functionValues, - ElemType learnRatePerSample, - ElemType momentum, - ElemType adaWeight, - ElemType adaMul, + ElemType learnRatePerSample, // alpha + ElemType momentum, // beta_1 + ElemType adaWeight, // beta_2 + ElemType adaMul, // biasCorrection ElemType epsilon, ElemType unitGainFactor, bool adamax) diff --git a/Source/Math/Matrix.cpp b/Source/Math/Matrix.cpp index 7a26dcc37..a377eda81 100755 --- a/Source/Math/Matrix.cpp +++ b/Source/Math/Matrix.cpp @@ -1827,7 +1827,11 @@ void Matrix::FSAdagradUpdate(Matrix& gradients, Matrix void Matrix::AdamUpdate(Matrix& gradients, Matrix& functionValues, const double smoothedCount, const double learnRatePerSample, const double meanMomentum, const double varMomentum, const double epsilon, ElemType unitGainFactor, bool adamax) diff --git a/Source/SGDLib/SGD.cpp b/Source/SGDLib/SGD.cpp index 35a14906d..a924a9256 100644 --- a/Source/SGDLib/SGD.cpp +++ b/Source/SGDLib/SGD.cpp @@ -2429,6 +2429,11 @@ void SGD::UpdateWeights(Matrix& functionValues, Matrix::ScaleAndAdd((ElemType)(-learnRatePerSample / aveMultiplier), gradientValues, functionValues); } + else if (adpType == GradientsUpdateType::Adam) + { + smoothedGradientValues.AdamUpdate(gradientValues, functionValues, smoothedCount + 1, learnRatePerSample, + m_adam.meanMomentum, m_adam.varMomentum, m_adam.epsilon, (ElemType)(1 - m_adam.meanMomentum), false); /* NOTE(review): the count argument is a temporary passed by value (AdamUpdate takes const double smoothedCount); confirm smoothedCount itself is advanced elsewhere in UpdateWeights, otherwise every update sees the same step count. unitGainFactor is fixed to 1 - meanMomentum and adamax is always false here. */ + } if (noiseStd > 0) { @@ -2840,6 +2845,7 @@ static GradientsUpdateType ParseGradUpdateType(const wstring& s) else if (EqualCI(s, L"adagrad")) return GradientsUpdateType::AdaGrad; else if (EqualCI(s, L"rmsProp")) return GradientsUpdateType::RmsProp; else if (EqualCI(s, L"fsAdagrad")) return GradientsUpdateType::FSAdaGrad; + else if 
(EqualCI(s, L"adam")) return GradientsUpdateType::Adam; // legacy, deprecated else if (EqualCI(s, L"normal") || EqualCI(s, L"simple")) return GradientsUpdateType::None; - else InvalidArgument("ParseGradUpdateType: Invalid Gradient Updating Type. Valid values are (none | adagrad | rmsProp | fsAdagrad )"); + else InvalidArgument("ParseGradUpdateType: Invalid Gradient Updating Type. Valid values are (none | adagrad | rmsProp | fsAdagrad | adam)"); @@ -3006,6 +3012,11 @@ SGDParams::SGDParams(const ConfigRecordType& configSGD, size_t sizeofElemType) m_rpi.max = configSGD(L"rms_wgt_max", 10.0); m_rpi.gamma = configSGD(L"rms_gamma", 0.99); + // Adam settings; the fallback values match the defaults set by the AdamInfo constructor + m_adam.meanMomentum = configSGD(L"adam_meanMomentum", 0.9); + m_adam.varMomentum = configSGD(L"adam_varMomentum", 0.999); + m_adam.epsilon = configSGD(L"adam_epsilon", 1e-8); + m_needAveMultiplier = configSGD(L"normWithAveMultiplier", true); m_L2RegWeight = configSGD(L"L2RegWeight", 0.0); m_L1RegWeight = configSGD(L"L1RegWeight", 0.0); diff --git a/Source/SGDLib/SGD.h b/Source/SGDLib/SGD.h index 9a71d1e54..04eb55073 100644 --- a/Source/SGDLib/SGD.h +++ b/Source/SGDLib/SGD.h @@ -55,7 +55,8 @@ enum class GradientsUpdateType : int None, AdaGrad, RmsProp, - FSAdaGrad + FSAdaGrad, + Adam }; // modelParallelSGD can be combined with dataParallelSGD/modelAveragingSGD/blockMomentumSGD @@ -91,6 +92,21 @@ struct RMSPropInfo } }; +struct AdamInfo +{ + double meanMomentum; // beta_1, default 0.9 + double varMomentum; // beta_2, default 0.999 + double epsilon; // default 1e-8 + + AdamInfo() + { + meanMomentum = 0.9; + varMomentum = 0.999; + epsilon = 1e-8; + } +}; + + struct GradientUpdateInfo { GradientsUpdateType type = GradientsUpdateType::AdaGrad; @@ -254,6 +270,7 @@ protected: GradientUpdateInfo m_gradType; RMSPropInfo m_rpi; + AdamInfo m_adam; size_t m_numMBsToShowResult = 0; size_t m_firstMBsToShowResult = 0;