Expose Adam in BS
Parent: 93e10096cb
Commit: b891ec0759
@@ -1537,10 +1537,10 @@ void GPUMatrix<ElemType>::FSAdagrad(GPUMatrix<ElemType>& gradients,
 template <class ElemType>
 void GPUMatrix<ElemType>::Adam(GPUMatrix<ElemType>& gradients,
     GPUMatrix<ElemType>& functionValues,
-    ElemType learnRatePerSample,
-    ElemType momentum,
-    ElemType adaWeight,
-    ElemType adaMul,
+    ElemType learnRatePerSample, //alpha
+    ElemType momentum, // /beta_1
+    ElemType adaWeight, // /beta_2
+    ElemType adaMul, //biasCorrection
     ElemType epsilon,
     ElemType unitGainFactor,
     bool adamax)
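For orientation, a minimal scalar sketch of the step these parameters drive follows. AdamStepScalar is a hypothetical helper written only for illustration, not the actual GPU kernel; it assumes adaMul carries the precomputed bias correction and that unitGainFactor is (1 - beta_1), as passed by the SGD caller further below, and the AdaMax branch reuses adaMul for its correction purely for brevity.

#include <algorithm>
#include <cmath>

// Hypothetical scalar illustration of one Adam step (not the real GPU code).
template <class ElemType>
void AdamStepScalar(ElemType& weight, ElemType grad,
                    ElemType& smoothedMean, ElemType& smoothedVar,
                    ElemType learnRatePerSample, // alpha
                    ElemType momentum,           // beta_1
                    ElemType adaWeight,          // beta_2
                    ElemType adaMul,             // assumed: precomputed bias correction
                    ElemType epsilon,
                    ElemType unitGainFactor,     // typically (1 - beta_1)
                    bool adamax)
{
    // Exponential moving average of the gradient (first moment).
    smoothedMean = momentum * smoothedMean + unitGainFactor * grad;
    if (!adamax)
    {
        // Exponential moving average of the squared gradient (second moment).
        smoothedVar = adaWeight * smoothedVar + (1 - adaWeight) * grad * grad;
        weight -= learnRatePerSample * adaMul * smoothedMean / (std::sqrt(smoothedVar) + epsilon);
    }
    else
    {
        // AdaMax variant from the same paper: infinity norm replaces the second moment.
        smoothedVar = std::max(adaWeight * smoothedVar, std::abs(grad));
        weight -= learnRatePerSample * adaMul * smoothedMean / (smoothedVar + epsilon);
    }
}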
@@ -1827,7 +1827,11 @@ void Matrix<ElemType>::FSAdagradUpdate(Matrix<ElemType>& gradients, Matrix<ElemT
 ///
 // Implement the original adam algorithm according to the paper
 // Ref: ADAM: A METHOD FOR STOCHASTIC OPTIMIZATION, https://arxiv.org/pdf/1412.6980.pdf
 ///
+///
+// Association between the method parameters and the paper notation:
+// smoothedCount - t
+// meanMomentum - /beta_1
+// varMomentum - /beta_2
 template <class ElemType>
 void Matrix<ElemType>::AdamUpdate(Matrix<ElemType>& gradients, Matrix<ElemType>& functionValues, const double smoothedCount,
     const double learnRatePerSample, const double meanMomentum, const double varMomentum, const double epsilon, ElemType unitGainFactor, bool adamax)
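In the paper's notation that these comments map to (t = smoothedCount, beta_1 = meanMomentum, beta_2 = varMomentum, and alpha, epsilon as the learning rate and stability constant), the referenced update is:

m_t = \beta_1 m_{t-1} + (1 - \beta_1)\, g_t
v_t = \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2
\hat{m}_t = m_t / (1 - \beta_1^t), \quad \hat{v}_t = v_t / (1 - \beta_2^t)
\theta_t = \theta_{t-1} - \alpha\, \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon)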
@@ -2429,6 +2429,11 @@ void SGD<ElemType>::UpdateWeights(Matrix<ElemType>& functionValues, Matrix<ElemT
                               (ElemType) m_rpi.dec, (ElemType) m_rpi.min, needAveMultiplier, true);
         Matrix<ElemType>::ScaleAndAdd((ElemType)(-learnRatePerSample / aveMultiplier), gradientValues, functionValues);
     }
+    else if (adpType == GradientsUpdateType::Adam)
+    {
+        smoothedGradientValues.AdamUpdate(gradientValues, functionValues, smoothedCount + 1, learnRatePerSample,
+                                          m_adam.meanMomentum, m_adam.varMomentum, m_adam.epsilon, (ElemType)(1 - m_adam.meanMomentum), false);
+    }

     if (noiseStd > 0)
     {
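Two details of this call are worth spelling out: smoothedCount + 1 plays the role of the time step t, the (ElemType)(1 - m_adam.meanMomentum) argument is the unit-gain factor scaling the new gradient in the first-moment average, and adamax is fixed to false. The helper below is purely illustrative and not part of this change; it assumes AdamUpdate applies the paper's bias correction internally from these inputs.

#include <cmath>

// Hypothetical illustration of the bias-correction term implied by the paper for
// step t = smoothedCount + 1; the real computation happens inside AdamUpdate.
inline double AdamBiasCorrection(double t, double meanMomentum /*beta_1*/, double varMomentum /*beta_2*/)
{
    return std::sqrt(1.0 - std::pow(varMomentum, t)) / (1.0 - std::pow(meanMomentum, t));
}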
@@ -2840,6 +2845,7 @@ static GradientsUpdateType ParseGradUpdateType(const wstring& s)
     else if (EqualCI(s, L"adagrad")) return GradientsUpdateType::AdaGrad;
     else if (EqualCI(s, L"rmsProp")) return GradientsUpdateType::RmsProp;
     else if (EqualCI(s, L"fsAdagrad")) return GradientsUpdateType::FSAdaGrad;
+    else if (EqualCI(s, L"adam")) return GradientsUpdateType::Adam;
     // legacy, deprecated
     else if (EqualCI(s, L"normal") || EqualCI(s, L"simple")) return GradientsUpdateType::None;
     else InvalidArgument("ParseGradUpdateType: Invalid Gradient Updating Type. Valid values are (none | adagrad | rmsProp | fsAdagrad )");
@@ -3006,6 +3012,11 @@ SGDParams::SGDParams(const ConfigRecordType& configSGD, size_t sizeofElemType)
     m_rpi.max = configSGD(L"rms_wgt_max", 10.0);
     m_rpi.gamma = configSGD(L"rms_gamma", 0.99);

+    // Adam settings
+    m_adam.meanMomentum = configSGD(L"adam_meanMomentum", 0.9);
+    m_adam.varMomentum = configSGD(L"adam_varMomentum", 0.999);
+    m_adam.epsilon = configSGD(L"adam_epsilon", pow(10, -8));
+
     m_needAveMultiplier = configSGD(L"normWithAveMultiplier", true);
     m_L2RegWeight = configSGD(L"L2RegWeight", 0.0);
     m_L1RegWeight = configSGD(L"L1RegWeight", 0.0);
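These adam_* keys are what the BrainScript exposure in the commit title amounts to on the config side. As a hedged illustration only, assuming the update type is still selected through the SGD block's gradUpdateType key (that key is not shown in this diff and is an assumption here), an SGD section enabling Adam with the defaults above might look like:

SGD = [
    gradUpdateType = "adam"        # assumed key name, parsed by ParseGradUpdateType above
    adam_meanMomentum = 0.9        # beta_1
    adam_varMomentum = 0.999       # beta_2
    adam_epsilon = 0.00000001      # 1e-8
]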
@@ -55,7 +55,8 @@ enum class GradientsUpdateType : int
     None,
     AdaGrad,
     RmsProp,
-    FSAdaGrad
+    FSAdaGrad,
+    Adam
 };

 // modelParallelSGD can be combined with dataParallelSGD/modelAveragingSGD/blockMomentumSGD
@@ -91,6 +92,21 @@ struct RMSPropInfo
     }
 };

+struct AdamInfo
+{
+    double meanMomentum; //beta_1
+    double varMomentum; //beta_2
+    double epsilon;
+
+    AdamInfo()
+    {
+        meanMomentum = 0.9;
+        varMomentum = 0.999;
+        epsilon = pow(10, -8);
+    }
+};
+
+
 struct GradientUpdateInfo
 {
     GradientsUpdateType type = GradientsUpdateType::AdaGrad;
@@ -254,6 +270,7 @@ protected:

     GradientUpdateInfo m_gradType;
     RMSPropInfo m_rpi;
+    AdamInfo m_adam;

     size_t m_numMBsToShowResult = 0;
     size_t m_firstMBsToShowResult = 0;