Vadim Mazalov 2018-12-18 15:36:45 -08:00
Parent 93e10096cb
Commit b891ec0759
4 changed files: 38 additions and 6 deletions

View file

@@ -1537,10 +1537,10 @@ void GPUMatrix<ElemType>::FSAdagrad(GPUMatrix<ElemType>& gradients,
template <class ElemType>
void GPUMatrix<ElemType>::Adam(GPUMatrix<ElemType>& gradients,
GPUMatrix<ElemType>& functionValues,
ElemType learnRatePerSample,
ElemType momentum,
ElemType adaWeight,
ElemType adaMul,
ElemType learnRatePerSample, // alpha
ElemType momentum, // \beta_1
ElemType adaWeight, // \beta_2
ElemType adaMul, // biasCorrection
ElemType epsilon,
ElemType unitGainFactor,
bool adamax)
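
For reference, per the cited paper these arguments drive the update below; a plausible reading (an assumption, not confirmed by this hunk) is that adaMul carries the combined bias correction \sqrt{1 - \beta_2^t} / (1 - \beta_1^t):

m_t = \beta_1 m_{t-1} + (1 - \beta_1)\, g_t
v_t = \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2
\theta_t = \theta_{t-1} - \alpha \, \frac{\sqrt{1 - \beta_2^t}}{1 - \beta_1^t} \cdot \frac{m_t}{\sqrt{v_t} + \epsilon}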

View file

@@ -1827,7 +1827,11 @@ void Matrix<ElemType>::FSAdagradUpdate(Matrix<ElemType>& gradients, Matrix<ElemT
///
// Implement the original adam algorithm according to the paper
// Ref: ADAM: A METHOD FOR STOCHASTIC OPTIMIZATION, https://arxiv.org/pdf/1412.6980.pdf
///
// Association between the method parameters and the paper notation:
// smoothedCount - t
// meanMomentum - \beta_1
// varMomentum - \beta_2
template <class ElemType>
void Matrix<ElemType>::AdamUpdate(Matrix<ElemType>& gradients, Matrix<ElemType>& functionValues, const double smoothedCount,
const double learnRatePerSample, const double meanMomentum, const double varMomentum, const double epsilon, ElemType unitGainFactor, bool adamax)
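
A minimal scalar sketch of the semantics this signature suggests, assuming smoothedCount is the paper's timestep t (the names, the epsilon placement, and the AdaMax denominator are assumptions, not the CNTK kernel):

#include <cmath>
#include <algorithm>

// Hypothetical scalar Adam/AdaMax step mirroring AdamUpdate's parameters.
// mean/var are the running moments carried across steps; returns the new weight.
double AdamStepSketch(double g, double& mean, double& var, double w,
                      double t, double lr, double beta1, double beta2,
                      double eps, double unitGainFactor, bool adamax)
{
    // Unit-gain mean update: the call site below passes (1 - beta1) here.
    mean = beta1 * mean + unitGainFactor * g;
    if (!adamax)
    {
        var = beta2 * var + (1 - beta2) * g * g;
        // Combined bias correction of both moments (paper eqns); the
        // placement of eps relative to the correction is an assumption.
        double biasCorrection = std::sqrt(1 - std::pow(beta2, t)) / (1 - std::pow(beta1, t));
        return w - lr * biasCorrection * mean / (std::sqrt(var) + eps);
    }
    // AdaMax variant: infinity-norm estimate replaces the second moment.
    var = std::max(beta2 * var, std::abs(g));
    return w - (lr / (1 - std::pow(beta1, t))) * mean / (var + eps);
}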

View file

@@ -2429,6 +2429,11 @@ void SGD<ElemType>::UpdateWeights(Matrix<ElemType>& functionValues, Matrix<ElemT
(ElemType) m_rpi.dec, (ElemType) m_rpi.min, needAveMultiplier, true);
Matrix<ElemType>::ScaleAndAdd((ElemType)(-learnRatePerSample / aveMultiplier), gradientValues, functionValues);
}
else if (adpType == GradientsUpdateType::Adam)
{
smoothedGradientValues.AdamUpdate(gradientValues, functionValues, smoothedCount + 1, learnRatePerSample,
m_adam.meanMomentum, m_adam.varMomentum, m_adam.epsilon, (ElemType)(1 - m_adam.meanMomentum), false);
}
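
Passing smoothedCount + 1 presumably keeps the timestep one-based as in the paper, the trailing false pins adamax off at this call site, and (1 - m_adam.meanMomentum) as the unit-gain factor makes the mean a convex combination of history and the new gradient. At t = 1 the bias correction on the mean then cancels exactly:

\hat{m}_1 = \frac{m_1}{1 - \beta_1} = \frac{(1 - \beta_1)\, g_1}{1 - \beta_1} = g_1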
if (noiseStd > 0)
{
@@ -2840,6 +2845,7 @@ static GradientsUpdateType ParseGradUpdateType(const wstring& s)
else if (EqualCI(s, L"adagrad")) return GradientsUpdateType::AdaGrad;
else if (EqualCI(s, L"rmsProp")) return GradientsUpdateType::RmsProp;
else if (EqualCI(s, L"fsAdagrad")) return GradientsUpdateType::FSAdaGrad;
else if (EqualCI(s, L"adam")) return GradientsUpdateType::Adam;
// legacy, deprecated
else if (EqualCI(s, L"normal") || EqualCI(s, L"simple")) return GradientsUpdateType::None;
else InvalidArgument("ParseGradUpdateType: Invalid Gradient Updating Type. Valid values are (none | adagrad | rmsProp | fsAdagrad | adam)");
@@ -3006,6 +3012,11 @@ SGDParams::SGDParams(const ConfigRecordType& configSGD, size_t sizeofElemType)
m_rpi.max = configSGD(L"rms_wgt_max", 10.0);
m_rpi.gamma = configSGD(L"rms_gamma", 0.99);
// Adam settings
m_adam.meanMomentum = configSGD(L"adam_meanMomentum", 0.9);
m_adam.varMomentum = configSGD(L"adam_varMomentum", 0.999);
m_adam.epsilon = configSGD(L"adam_epsilon", 1e-8);
m_needAveMultiplier = configSGD(L"normWithAveMultiplier", true);
m_L2RegWeight = configSGD(L"L2RegWeight", 0.0);
m_L1RegWeight = configSGD(L"L1RegWeight", 0.0);
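
A hypothetical config fragment exercising the new settings (the adam_* keys and their defaults come from this diff; the SGD block syntax and the gradUpdateType selector key are assumptions based on ParseGradUpdateType above):

SGD = [
    gradUpdateType = "adam"
    adam_meanMomentum = 0.9
    adam_varMomentum = 0.999
    adam_epsilon = 1e-8
]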

View file

@@ -55,7 +55,8 @@ enum class GradientsUpdateType : int
None,
AdaGrad,
RmsProp,
FSAdaGrad
FSAdaGrad,
Adam
};
// modelParallelSGD can be combined with dataParallelSGD/modelAveragingSGD/blockMomentumSGD
@@ -91,6 +92,21 @@ struct RMSPropInfo
}
};
struct AdamInfo
{
double meanMomentum; // \beta_1
double varMomentum; // \beta_2
double epsilon;
AdamInfo()
{
meanMomentum = 0.9;
varMomentum = 0.999;
epsilon = 1e-8;
}
};
struct GradientUpdateInfo
{
GradientsUpdateType type = GradientsUpdateType::AdaGrad;
@@ -254,6 +270,7 @@ protected:
GradientUpdateInfo m_gradType;
RMSPropInfo m_rpi;
AdamInfo m_adam;
size_t m_numMBsToShowResult = 0;
size_t m_firstMBsToShowResult = 0;