Add unit-gain momentum flag
This commit contains a breaking change: it adds a flag to both C++ and python Learner API that indicates if the momentum should be applied in the regular fashion or as a unit-gain filter. This flag is a required parameter for learners that use momentum (momentum_sgd, nesterov_sgd and adam_sgd).
This commit is contained in:
Родитель
43b30a1c41
Коммит
3ab246855b
|
@ -73,7 +73,9 @@ def convnet_cifar10(debug_output=False):
|
|||
l2_reg_weight = 0.002
|
||||
|
||||
# Instantiate the trainer object to drive the model training
|
||||
learner = cntk.learner.momentum_sgd(z.parameters, lr_schedule, mm_schedule, l2_regularization_weight = l2_reg_weight)
|
||||
learner = cntk.learner.momentum_sgd(z.parameters, lr_schedule, mm_schedule,
|
||||
unit_gain = True,
|
||||
l2_regularization_weight = l2_reg_weight)
|
||||
trainer = cntk.Trainer(z, ce, pe, learner)
|
||||
|
||||
# define mapping from reader streams to network inputs
|
||||
|
|
|
@ -85,6 +85,7 @@ def convnet_cifar10_dataaug(reader_train, reader_test, max_epochs = 80):
|
|||
|
||||
# trainer object
|
||||
learner = cntk.learner.momentum_sgd(z.parameters, lr_schedule, mm_schedule,
|
||||
unit_gain = True,
|
||||
l2_regularization_weight = l2_reg_weight)
|
||||
trainer = cntk.Trainer(z, ce, pe, learner)
|
||||
|
||||
|
|
|
@ -93,7 +93,9 @@ def convnet_cifar10_dataaug(create_train_reader, test_reader, create_dist_learne
|
|||
|
||||
# trainer object
|
||||
learner = create_dist_learner(
|
||||
cntk.learner.momentum_sgd(z.parameters, lr_schedule, mm_schedule, l2_regularization_weight=l2_reg_weight))
|
||||
cntk.learner.momentum_sgd(z.parameters, lr_schedule, mm_schedule,
|
||||
unit_gain = True,
|
||||
l2_regularization_weight=l2_reg_weight))
|
||||
|
||||
trainer = cntk.Trainer(z, ce, pe, learner)
|
||||
|
||||
|
|
|
@ -64,7 +64,7 @@ def convnet_mnist(debug_output=False):
|
|||
mm_schedule = cntk.learner.momentum_as_time_constant_schedule(mm_time_constant, epoch_size)
|
||||
|
||||
# Instantiate the trainer object to drive the model training
|
||||
learner = cntk.learner.momentum_sgd(z.parameters, lr_schedule, mm_schedule)
|
||||
learner = cntk.learner.momentum_sgd(z.parameters, lr_schedule, mm_schedule, unit_gain=True)
|
||||
trainer = cntk.Trainer(z, ce, pe, learner)
|
||||
|
||||
# define mapping from reader streams to network inputs
|
||||
|
|
|
@ -88,6 +88,7 @@ def train_and_evaluate(reader_train, reader_test, network_name, max_epochs):
|
|||
|
||||
# trainer object
|
||||
learner = momentum_sgd(z.parameters, lr_schedule, mm_schedule,
|
||||
unit_gain = True,
|
||||
l2_regularization_weight = l2_reg_weight)
|
||||
trainer = Trainer(z, ce, pe, learner)
|
||||
|
||||
|
|
|
@ -98,6 +98,7 @@ def train_and_evaluate(create_train_reader, test_reader, network_name, max_epoch
|
|||
|
||||
# trainer object
|
||||
learner = create_dist_learner(momentum_sgd(z.parameters, lr_schedule, mm_schedule,
|
||||
unit_gain = True,
|
||||
l2_regularization_weight = l2_reg_weight))
|
||||
trainer = Trainer(z, ce, pe, learner)
|
||||
|
||||
|
|
|
@ -83,6 +83,7 @@ def train(reader, model, max_epochs):
|
|||
lr_per_sample = learning_rate_schedule(lr_schedule, UnitType.sample, epoch_size)
|
||||
learner = adam_sgd(z.parameters,
|
||||
lr=lr_per_sample, momentum=momentum_time_constant,
|
||||
unit_gain=True,
|
||||
low_memory=True,
|
||||
gradient_clipping_threshold_per_sample=15, gradient_clipping_with_truncation=True)
|
||||
|
||||
|
|
|
@ -158,6 +158,7 @@ def sequence_to_sequence_translator(debug_output=False, run_test=False):
|
|||
gradient_clipping_with_truncation = True
|
||||
learner = momentum_sgd(z.parameters,
|
||||
lr_per_minibatch, momentum_time_constant,
|
||||
unit_gain=True,
|
||||
gradient_clipping_threshold_per_sample=clipping_threshold_per_sample,
|
||||
gradient_clipping_with_truncation=gradient_clipping_with_truncation)
|
||||
trainer = Trainer(z, ce, errs, learner)
|
||||
|
@ -241,7 +242,8 @@ def sequence_to_sequence_translator(debug_output=False, run_test=False):
|
|||
ce = cross_entropy_with_softmax(z, label_sequence)
|
||||
errs = classification_error(z, label_sequence)
|
||||
trainer = Trainer(z, ce, errs, [momentum_sgd(
|
||||
z.parameters, lr_per_minibatch, momentum_time_constant, clipping_threshold_per_sample, gradient_clipping_with_truncation)])
|
||||
z.parameters, lr_per_minibatch, momentum_time_constant, True,
|
||||
clipping_threshold_per_sample, gradient_clipping_with_truncation)])
|
||||
|
||||
error2 = translator_test_error(z, trainer, input_vocab_dim, label_vocab_dim)
|
||||
|
||||
|
|
|
@ -164,6 +164,7 @@ def train_lm(training_file):
|
|||
clipping_threshold_per_sample = 5.0
|
||||
gradient_clipping_with_truncation = True
|
||||
learner = momentum_sgd(z.parameters, lr_per_sample, momentum_time_constant,
|
||||
unit_gain=True,
|
||||
gradient_clipping_threshold_per_sample=clipping_threshold_per_sample,
|
||||
gradient_clipping_with_truncation=gradient_clipping_with_truncation)
|
||||
trainer = Trainer(z, ce, errs, learner)
|
||||
|
|
|
@ -198,7 +198,7 @@ def conv3d_ucf11(train_reader, test_reader, max_epochs=30):
|
|||
mm_schedule = momentum_as_time_constant_schedule(momentum_time_constant, epoch_size=epoch_size)
|
||||
|
||||
# Instantiate the trainer object to drive the model training
|
||||
learner = momentum_sgd(z.parameters, lr_schedule, mm_schedule)
|
||||
learner = momentum_sgd(z.parameters, lr_schedule, mm_schedule, True)
|
||||
trainer = Trainer(z, ce, pe, learner)
|
||||
|
||||
log_number_of_parameters(z) ; print()
|
||||
|
|
|
@ -3531,6 +3531,7 @@ namespace CNTK
|
|||
CNTK_API LearnerPtr MomentumSGDLearner(const std::vector<Parameter>& parameters,
|
||||
const LearningRateSchedule& learningRateSchedule,
|
||||
const MomentumSchedule& momentumSchedule,
|
||||
bool unitGain,
|
||||
AdditionalLearningOptions additionalOptions = AdditionalLearningOptions());
|
||||
|
||||
///
|
||||
|
@ -3539,6 +3540,7 @@ namespace CNTK
|
|||
CNTK_API LearnerPtr NesterovLearner(const std::vector<Parameter>& parameters,
|
||||
const LearningRateSchedule& learningRateSchedule,
|
||||
const MomentumSchedule& momentumSchedule,
|
||||
bool unitGain,
|
||||
AdditionalLearningOptions additionalOptions = AdditionalLearningOptions());
|
||||
|
||||
static MomentumSchedule DefaultVarianceMomentum = MomentumAsTimeConstantSchedule(2 * 3600 * 100);
|
||||
|
@ -3549,6 +3551,7 @@ namespace CNTK
|
|||
CNTK_API LearnerPtr AdamLearner(const std::vector<Parameter>& parameters,
|
||||
const LearningRateSchedule& learningRateSchedule,
|
||||
const MomentumSchedule& momentumSchedule,
|
||||
bool unitGain,
|
||||
const MomentumSchedule& varianceMomentumSchedule = DefaultVarianceMomentum,
|
||||
bool lowMemory = true,
|
||||
AdditionalLearningOptions additionalOptions = AdditionalLearningOptions());
|
||||
|
|
|
@ -9,7 +9,7 @@
|
|||
#include "Utils.h"
|
||||
#include "Serialization.h"
|
||||
|
||||
#define UPDATE_FUNCTION \
|
||||
#define DISPATCH_TO_TYPED_UPDATE_FUNCTION \
|
||||
switch (smoothedGradientValue->GetDataType()) \
|
||||
{ \
|
||||
case DataType::Float: \
|
||||
|
@ -22,6 +22,11 @@
|
|||
NOT_IMPLEMENTED; \
|
||||
}
|
||||
|
||||
#define GET_WRITABLE_MATRICES \
|
||||
const auto& smoothedGradientMatrix = GetWritableMatrix<ElementType>(smoothedGradientValue); \
|
||||
const auto& gradientMatrix = GetWritableMatrix<ElementType>(gradientValue); \
|
||||
const auto& parameterMatrix = GetWritableMatrix<ElementType>(parameter.Value());
|
||||
|
||||
using namespace Microsoft::MSR::CNTK;
|
||||
using namespace std;
|
||||
|
||||
|
@ -184,15 +189,13 @@ namespace CNTK
|
|||
LogicError("Learner parameters contain duplicates.");
|
||||
}
|
||||
|
||||
for (const auto& parameter : parameters)
|
||||
if (allocateSmoothGradients)
|
||||
{
|
||||
if (!allocateSmoothGradients)
|
||||
for (const auto& parameter : parameters)
|
||||
{
|
||||
continue;
|
||||
NDArrayViewPtr view = AllocateNDArrayView(parameter, parameter.Shape());
|
||||
m_smoothedGradientValues.emplace(parameter, view);
|
||||
}
|
||||
|
||||
NDArrayViewPtr view = AllocateNDArrayView(parameter, parameter.Shape());
|
||||
m_smoothedGradientValues.insert(make_pair(parameter, view));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -256,7 +259,7 @@ namespace CNTK
|
|||
Print(gradientValue, "Gradient Update");
|
||||
Print(smoothedGradientValue, "Smoothed Gradient Input");
|
||||
#endif
|
||||
UPDATE_FUNCTION;
|
||||
DISPATCH_TO_TYPED_UPDATE_FUNCTION;
|
||||
|
||||
#if DUMPOUTPUT
|
||||
Print(parameter.Value(), "Parameter Update");
|
||||
|
@ -275,7 +278,8 @@ namespace CNTK
|
|||
}
|
||||
|
||||
template <typename ElementType>
|
||||
void LearnerBase::Update(const Parameter& parameter, const NDArrayViewPtr& gradientValue, const NDArrayViewPtr& smoothedGradientValue, size_t trainingSampleCount) const
|
||||
void LearnerBase::Update(const Parameter& parameter, const NDArrayViewPtr& gradientValue,
|
||||
const NDArrayViewPtr& smoothedGradientValue, size_t trainingSampleCount) const
|
||||
{
|
||||
const auto& parameterValue = parameter.Value();
|
||||
PreProcess<ElementType>(parameterValue, gradientValue, trainingSampleCount);
|
||||
|
@ -364,27 +368,39 @@ namespace CNTK
|
|||
}
|
||||
}
|
||||
|
||||
/*virtual*/ void LearnerSGD::Update(const Parameter& parameter, const NDArrayViewPtr& gradientValue, const NDArrayViewPtr& smoothedGradientValue, size_t trainingSampleCount) const /*override*/
|
||||
LearnerSGD::LearnerSGD(const std::vector<Parameter>& parameters,
|
||||
const LearningRateSchedule& learningRateSchedule,
|
||||
AdditionalLearningOptions additionalOptions,
|
||||
bool allocateSmoothGradients)
|
||||
: LearnerBase(parameters, learningRateSchedule, additionalOptions, allocateSmoothGradients)
|
||||
{
|
||||
UPDATE_FUNCTION;
|
||||
if (!allocateSmoothGradients)
|
||||
{
|
||||
// the vanilla sgd does not need the smooth gradients per se,
|
||||
// insert dummy nd views instead.
|
||||
for (const auto& parameter : parameters)
|
||||
{
|
||||
m_smoothedGradientValues.emplace(parameter, AllocateNDArrayView(parameter, {}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*virtual*/ void LearnerSGD::Update(const Parameter& parameter, const NDArrayViewPtr& gradientValue,
|
||||
const NDArrayViewPtr& smoothedGradientValue, size_t trainingSampleCount) const /*override*/
|
||||
{
|
||||
DISPATCH_TO_TYPED_UPDATE_FUNCTION;
|
||||
}
|
||||
|
||||
template <typename ElementType>
|
||||
void LearnerSGD::Update(const Parameter& parameter, const NDArrayViewPtr& gradientValue, const NDArrayViewPtr& smoothedGradientValue, size_t trainingSampleCount) const
|
||||
void LearnerSGD::Update(const Parameter& parameter, const NDArrayViewPtr& gradientValue,
|
||||
const NDArrayViewPtr& smoothedGradientValue, size_t trainingSampleCount) const
|
||||
{
|
||||
const auto& parameterValue = parameter.Value();
|
||||
const auto& smoothedGradientMatrix = GetWritableMatrix<ElementType>(smoothedGradientValue);
|
||||
UNUSED(smoothedGradientValue);
|
||||
const auto& gradientMatrix = GetWritableMatrix<ElementType>(gradientValue);
|
||||
const auto& parameterMatrix = GetWritableMatrix<ElementType>(parameterValue);
|
||||
|
||||
const auto& parameterMatrix = GetWritableMatrix<ElementType>(parameter.Value());
|
||||
const auto learningRate = ElementType(LearningRate(trainingSampleCount));
|
||||
const auto momentum = ElementType(MomentumValueForMB(trainingSampleCount));
|
||||
|
||||
// TODO: break up the NormalGrad into 3 different functions, each with its own set of parameters
|
||||
// Also, come up with a better name for NormalGrad (Default? Regular? Plain?).
|
||||
// (one for vanilla SGD, the other for momentum SGD, and the third one for NAG).
|
||||
smoothedGradientMatrix->NormalGrad(*gradientMatrix, *parameterMatrix,
|
||||
learningRate, momentum, UseNesterovMomentum());
|
||||
parameterMatrix->SGDUpdate(*gradientMatrix, learningRate);
|
||||
}
|
||||
|
||||
double LearnerMomentumSGD::MomentumValueForMB(const MomentumSchedule& schedule, size_t minibatchSize) const
|
||||
|
@ -397,6 +413,44 @@ namespace CNTK
|
|||
return std::pow(currentMomentum, minibatchSize);
|
||||
}
|
||||
|
||||
/*virtual*/ void LearnerMomentumSGD::Update(const Parameter& parameter, const NDArrayViewPtr& gradientValue,
|
||||
const NDArrayViewPtr& smoothedGradientValue, size_t trainingSampleCount) const /*override*/
|
||||
{
|
||||
DISPATCH_TO_TYPED_UPDATE_FUNCTION;
|
||||
}
|
||||
|
||||
template <typename ElementType>
|
||||
void LearnerMomentumSGD::Update(const Parameter& parameter, const NDArrayViewPtr& gradientValue,
|
||||
const NDArrayViewPtr& smoothedGradientValue, size_t trainingSampleCount) const
|
||||
{
|
||||
GET_WRITABLE_MATRICES;
|
||||
|
||||
const auto learningRate = ElementType(LearningRate(trainingSampleCount));
|
||||
const auto momentum = ElementType(MomentumValueForMB(trainingSampleCount));
|
||||
|
||||
parameterMatrix->MomentumSGDUpdate(*gradientMatrix, *smoothedGradientMatrix,
|
||||
learningRate, momentum, UseUnitGainMomentum());
|
||||
}
|
||||
|
||||
/*virtual*/ void LearnerNesterov::Update(const Parameter& parameter, const NDArrayViewPtr& gradientValue,
|
||||
const NDArrayViewPtr& smoothedGradientValue, size_t trainingSampleCount) const /*override*/
|
||||
{
|
||||
DISPATCH_TO_TYPED_UPDATE_FUNCTION;
|
||||
}
|
||||
|
||||
template <typename ElementType>
|
||||
void LearnerNesterov::Update(const Parameter& parameter, const NDArrayViewPtr& gradientValue,
|
||||
const NDArrayViewPtr& smoothedGradientValue, size_t trainingSampleCount) const
|
||||
{
|
||||
GET_WRITABLE_MATRICES;
|
||||
|
||||
const auto learningRate = ElementType(LearningRate(trainingSampleCount));
|
||||
const auto momentum = ElementType(MomentumValueForMB(trainingSampleCount));
|
||||
|
||||
parameterMatrix->NesterovAcceleratedMomentumSGDUpdate(*gradientMatrix, *smoothedGradientMatrix,
|
||||
learningRate, momentum, UseUnitGainMomentum());
|
||||
}
|
||||
|
||||
LearnerAdaGrad::LearnerAdaGrad(const std::vector<Parameter>& parameters,
|
||||
const LearningRateSchedule& learningRateSchedule,
|
||||
bool needAveMultiplier,
|
||||
|
@ -416,24 +470,21 @@ namespace CNTK
|
|||
const auto shape = GetMatrixShape(parameter);
|
||||
NDArrayViewPtr view = AllocateNDArrayView(parameter, { shape[0], factor * shape[1] });
|
||||
|
||||
m_smoothedGradientValues.insert(make_pair(parameter, view));
|
||||
m_smoothedGradientValues.emplace(parameter, view);
|
||||
}
|
||||
}
|
||||
|
||||
/*virtual*/ void LearnerAdaGrad::Update(const Parameter& parameter, const NDArrayViewPtr& gradientValue, const NDArrayViewPtr& smoothedGradientValue, size_t trainingSampleCount) const /*override*/
|
||||
/*virtual*/ void LearnerAdaGrad::Update(const Parameter& parameter, const NDArrayViewPtr& gradientValue,
|
||||
const NDArrayViewPtr& smoothedGradientValue, size_t trainingSampleCount) const /*override*/
|
||||
{
|
||||
UPDATE_FUNCTION;
|
||||
DISPATCH_TO_TYPED_UPDATE_FUNCTION;
|
||||
}
|
||||
|
||||
template <typename ElementType>
|
||||
void LearnerAdaGrad::Update(const Parameter& parameter, const NDArrayViewPtr& gradientValue, const NDArrayViewPtr& smoothedGradientValue, size_t trainingSampleCount) const
|
||||
void LearnerAdaGrad::Update(const Parameter& parameter, const NDArrayViewPtr& gradientValue,
|
||||
const NDArrayViewPtr& smoothedGradientValue, size_t trainingSampleCount) const
|
||||
{
|
||||
UNUSED(trainingSampleCount);
|
||||
|
||||
const auto& parameterValue = parameter.Value();
|
||||
const auto& smoothedGradientMatrix = GetWritableMatrix<ElementType>(smoothedGradientValue);
|
||||
const auto& gradientMatrix = GetWritableMatrix<ElementType>(gradientValue);
|
||||
const auto& parameterMatrix = GetWritableMatrix<ElementType>(parameterValue);
|
||||
GET_WRITABLE_MATRICES
|
||||
|
||||
const auto learningRate = LearningRate(trainingSampleCount);
|
||||
|
||||
|
@ -446,32 +497,33 @@ namespace CNTK
|
|||
LearnerFSAdaGrad::LearnerFSAdaGrad(const vector<Parameter>& parameters,
|
||||
const LearningRateSchedule& learningRateSchedule,
|
||||
const MomentumSchedule& momentumSchedule,
|
||||
bool unitGain,
|
||||
const MomentumSchedule& varianceMomentumSchedule,
|
||||
AdditionalLearningOptions additionalOptions)
|
||||
: LearnerMomentumSGD(parameters, learningRateSchedule, momentumSchedule, additionalOptions, /*allocateSmoothGradients*/ false),
|
||||
: LearnerMomentumSGD(parameters, learningRateSchedule, momentumSchedule,
|
||||
unitGain, additionalOptions, /*allocateSmoothGradients*/ false),
|
||||
m_varianceMomentumSchedule(varianceMomentumSchedule)
|
||||
{
|
||||
for (const auto& parameter : parameters)
|
||||
{
|
||||
const auto shape = GetMatrixShape(parameter);
|
||||
NDArrayViewPtr view = AllocateNDArrayView(parameter, { shape[0], 2 * shape[1] });
|
||||
m_smoothedGradientValues.insert(make_pair(parameter, view));
|
||||
m_smoothedCounts.insert(make_pair(parameter, 0.0));
|
||||
m_smoothedGradientValues.emplace(parameter, view);
|
||||
m_smoothedCounts.emplace(parameter, 0.0);
|
||||
}
|
||||
}
|
||||
|
||||
/*virtual*/ void LearnerFSAdaGrad::Update(const Parameter& parameter, const NDArrayViewPtr& gradientValue, const NDArrayViewPtr& smoothedGradientValue, size_t trainingSampleCount) const /*override*/
|
||||
/*virtual*/ void LearnerFSAdaGrad::Update(const Parameter& parameter, const NDArrayViewPtr& gradientValue,
|
||||
const NDArrayViewPtr& smoothedGradientValue, size_t trainingSampleCount) const /*override*/
|
||||
{
|
||||
UPDATE_FUNCTION;
|
||||
DISPATCH_TO_TYPED_UPDATE_FUNCTION;
|
||||
}
|
||||
|
||||
template <typename ElementType>
|
||||
void LearnerFSAdaGrad::Update(const Parameter& parameter, const NDArrayViewPtr& gradientValue, const NDArrayViewPtr& smoothedGradientValue, size_t trainingSampleCount) const
|
||||
void LearnerFSAdaGrad::Update(const Parameter& parameter, const NDArrayViewPtr& gradientValue,
|
||||
const NDArrayViewPtr& smoothedGradientValue, size_t trainingSampleCount) const
|
||||
{
|
||||
const auto& parameterValue = parameter.Value();
|
||||
const auto& smoothedGradientMatrix = GetWritableMatrix<ElementType>(smoothedGradientValue);
|
||||
const auto& gradientMatrix = GetWritableMatrix<ElementType>(gradientValue);
|
||||
const auto& parameterMatrix = GetWritableMatrix<ElementType>(parameterValue);
|
||||
GET_WRITABLE_MATRICES;
|
||||
|
||||
const auto learningRate = LearningRate(trainingSampleCount);
|
||||
const auto momentum = MomentumValueForMB(trainingSampleCount);
|
||||
|
@ -480,7 +532,8 @@ namespace CNTK
|
|||
|
||||
double& smoothedCount = m_smoothedCounts.at(parameter);
|
||||
|
||||
smoothedGradientMatrix->FSAdagradUpdate(trainingSampleCount, *gradientMatrix, *parameterMatrix, smoothedCount, learningRate, s_targetAdagradAvDenom, momentum, varMomentum);
|
||||
smoothedGradientMatrix->FSAdagradUpdate(trainingSampleCount, *gradientMatrix, *parameterMatrix, smoothedCount, learningRate,
|
||||
s_targetAdagradAvDenom, momentum, varMomentum, UseUnitGainMomentum());
|
||||
}
|
||||
|
||||
LearnerRMSProp::LearnerRMSProp(const vector<Parameter>& parameters,
|
||||
|
@ -503,24 +556,21 @@ namespace CNTK
|
|||
const auto shape = GetMatrixShape(parameter);
|
||||
NDArrayViewPtr view = AllocateNDArrayView(parameter, { shape[0], factor * shape[1] });
|
||||
|
||||
m_smoothedGradientValues.insert(make_pair(parameter, view));
|
||||
m_smoothedGradientValues.emplace(parameter, view);
|
||||
}
|
||||
}
|
||||
|
||||
/*virtual*/ void LearnerRMSProp::Update(const Parameter& parameter, const NDArrayViewPtr& gradientValue, const NDArrayViewPtr& smoothedGradientValue, size_t trainingSampleCount) const /*override*/
|
||||
/*virtual*/ void LearnerRMSProp::Update(const Parameter& parameter, const NDArrayViewPtr& gradientValue,
|
||||
const NDArrayViewPtr& smoothedGradientValue, size_t trainingSampleCount) const /*override*/
|
||||
{
|
||||
UPDATE_FUNCTION;
|
||||
DISPATCH_TO_TYPED_UPDATE_FUNCTION;
|
||||
}
|
||||
|
||||
template <typename ElementType>
|
||||
void LearnerRMSProp::Update(const Parameter& parameter, const NDArrayViewPtr& gradientValue, const NDArrayViewPtr& smoothedGradientValue, size_t trainingSampleCount) const
|
||||
void LearnerRMSProp::Update(const Parameter& parameter, const NDArrayViewPtr& gradientValue,
|
||||
const NDArrayViewPtr& smoothedGradientValue, size_t trainingSampleCount) const
|
||||
{
|
||||
UNUSED(trainingSampleCount);
|
||||
|
||||
const auto& parameterValue = parameter.Value();
|
||||
const auto& smoothedGradientMatrix = GetWritableMatrix<ElementType>(smoothedGradientValue);
|
||||
const auto& gradientMatrix = GetWritableMatrix<ElementType>(gradientValue);
|
||||
const auto& parameterMatrix = GetWritableMatrix<ElementType>(parameterValue);
|
||||
GET_WRITABLE_MATRICES;
|
||||
|
||||
const auto learningRate = LearningRate(trainingSampleCount);
|
||||
|
||||
|
@ -548,22 +598,25 @@ namespace CNTK
|
|||
LearnerPtr MomentumSGDLearner(const vector<Parameter>& parameters,
|
||||
const LearningRateSchedule& learningRateSchedule,
|
||||
const MomentumSchedule& momentumSchedule,
|
||||
bool unitGain,
|
||||
AdditionalLearningOptions additionalOptions /*= AdditionalLearningOptions()*/)
|
||||
{
|
||||
return MakeSharedObject<LearnerMomentumSGD>(parameters, learningRateSchedule, momentumSchedule, additionalOptions);
|
||||
return MakeSharedObject<LearnerMomentumSGD>(parameters, learningRateSchedule, momentumSchedule, unitGain, additionalOptions);
|
||||
}
|
||||
|
||||
LearnerPtr NesterovLearner(const vector<Parameter>& parameters,
|
||||
const LearningRateSchedule& learningRateSchedule,
|
||||
const MomentumSchedule& momentumSchedule,
|
||||
bool unitGain,
|
||||
AdditionalLearningOptions additionalOptions /*= AdditionalLearningOptions()*/)
|
||||
{
|
||||
return MakeSharedObject<LearnerNesterov>(parameters, learningRateSchedule, momentumSchedule, additionalOptions);
|
||||
return MakeSharedObject<LearnerNesterov>(parameters, learningRateSchedule, momentumSchedule, unitGain, additionalOptions);
|
||||
}
|
||||
|
||||
LearnerPtr AdamLearner(const vector<Parameter>& parameters,
|
||||
const LearningRateSchedule& learningRateSchedule,
|
||||
const MomentumSchedule& momentumSchedule,
|
||||
bool unitGain,
|
||||
const MomentumSchedule& varianceMomentumSchedule, /*= MomentumAsTimeConstantSchedulePerSample(2 * 3600 * 100)*/
|
||||
bool lowMemory, /*= true*/
|
||||
AdditionalLearningOptions additionalOptions /*= AdditionalLearningOptions()*/)
|
||||
|
@ -572,7 +625,7 @@ namespace CNTK
|
|||
{
|
||||
LogicError("AdamLearner: only the low-memory variant is supported at the moment.");
|
||||
}
|
||||
return MakeSharedObject<LearnerFSAdaGrad>(parameters, learningRateSchedule, momentumSchedule, varianceMomentumSchedule, additionalOptions);
|
||||
return MakeSharedObject<LearnerFSAdaGrad>(parameters, learningRateSchedule, momentumSchedule, unitGain, varianceMomentumSchedule, additionalOptions);
|
||||
}
|
||||
|
||||
LearnerPtr AdaGradLearner(const vector<Parameter>& parameters,
|
||||
|
|
|
@ -108,26 +108,13 @@ namespace CNTK
|
|||
};
|
||||
|
||||
// Vanilla gradient descent optimization algorithm.
|
||||
class LearnerSGD : public LearnerBase
|
||||
class LearnerSGD final : public LearnerBase
|
||||
{
|
||||
public:
|
||||
LearnerSGD(const std::vector<Parameter>& parameters,
|
||||
const LearningRateSchedule& learningRateSchedule,
|
||||
AdditionalLearningOptions additionalOptions,
|
||||
bool allocateSmoothGradients = true)
|
||||
: LearnerBase(parameters, learningRateSchedule, additionalOptions, allocateSmoothGradients)
|
||||
{}
|
||||
|
||||
// TODO: get rid of this as soon as NormalGrad is refactored.
|
||||
virtual double MomentumValueForMB(size_t /*minibatchSize*/) const
|
||||
{
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
virtual bool UseNesterovMomentum() const
|
||||
{
|
||||
return false;
|
||||
}
|
||||
bool allocateSmoothGradients = false);
|
||||
|
||||
protected:
|
||||
|
||||
|
@ -138,30 +125,45 @@ namespace CNTK
|
|||
};
|
||||
|
||||
// SGD optimization with momentum.
|
||||
class LearnerMomentumSGD : public LearnerSGD
|
||||
class LearnerMomentumSGD : public LearnerBase
|
||||
{
|
||||
public:
|
||||
LearnerMomentumSGD(const std::vector<Parameter>& parameters,
|
||||
const LearningRateSchedule& learningRateSchedule,
|
||||
const MomentumSchedule& momentumSchedule,
|
||||
bool unitGain,
|
||||
AdditionalLearningOptions additionalOptions,
|
||||
bool allocateSmoothGradients = true)
|
||||
: LearnerSGD(parameters, learningRateSchedule, additionalOptions, allocateSmoothGradients),
|
||||
m_momentumSchedule(momentumSchedule)
|
||||
: LearnerBase(parameters, learningRateSchedule, additionalOptions, allocateSmoothGradients),
|
||||
m_momentumSchedule(momentumSchedule),
|
||||
m_unitGain(unitGain)
|
||||
{ }
|
||||
|
||||
// returns current per-minibatch momentum value.
|
||||
virtual double MomentumValueForMB(size_t minibatchSize) const override
|
||||
virtual double MomentumValueForMB(size_t minibatchSize) const
|
||||
{
|
||||
return MomentumValueForMB(m_momentumSchedule, minibatchSize);
|
||||
}
|
||||
|
||||
protected:
|
||||
virtual void Update(const Parameter& parameter, const NDArrayViewPtr& gradientValue, const NDArrayViewPtr& smoothedGradientValue, size_t trainingSampleCount) const override;
|
||||
|
||||
template <typename ElementType>
|
||||
void Update(const Parameter& parameter, const NDArrayViewPtr& gradientValue, const NDArrayViewPtr& smoothedGradientValue, size_t trainingSampleCount) const;
|
||||
|
||||
// returns current per-minibatch momentum value from the provided schedule.
|
||||
double MomentumValueForMB(const MomentumSchedule& schedule, size_t minibatchSize) const;
|
||||
|
||||
// Return true if the update should use classic momentum and
|
||||
// false if the unit-gain momentum should be used instead.
|
||||
bool UseUnitGainMomentum() const
|
||||
{
|
||||
return m_unitGain;
|
||||
}
|
||||
|
||||
private:
|
||||
MomentumSchedule m_momentumSchedule;
|
||||
bool m_unitGain;
|
||||
};
|
||||
|
||||
// Nesterov's accelerated SGDLearnerBase descent.
|
||||
|
@ -172,14 +174,16 @@ namespace CNTK
|
|||
LearnerNesterov(const std::vector<Parameter>& parameters,
|
||||
const LearningRateSchedule& learningRateSchedule,
|
||||
const MomentumSchedule& momentumSchedule,
|
||||
bool unitGain,
|
||||
AdditionalLearningOptions additionalOptions)
|
||||
: LearnerMomentumSGD(parameters, learningRateSchedule, momentumSchedule, additionalOptions, /*allocateSmoothGradients*/ true)
|
||||
: LearnerMomentumSGD(parameters, learningRateSchedule, momentumSchedule, unitGain, additionalOptions, /*allocateSmoothGradients*/ true)
|
||||
{}
|
||||
|
||||
virtual bool UseNesterovMomentum() const override
|
||||
{
|
||||
return true;
|
||||
}
|
||||
protected:
|
||||
virtual void Update(const Parameter& parameter, const NDArrayViewPtr& gradientValue, const NDArrayViewPtr& smoothedGradientValue, size_t trainingSampleCount) const override;
|
||||
|
||||
template <typename ElementType>
|
||||
void Update(const Parameter& parameter, const NDArrayViewPtr& gradientValue, const NDArrayViewPtr& smoothedGradientValue, size_t trainingSampleCount) const;
|
||||
};
|
||||
|
||||
class LearnerAdaGrad : public LearnerBase
|
||||
|
@ -206,6 +210,7 @@ namespace CNTK
|
|||
LearnerFSAdaGrad(const std::vector<Parameter>& parameters,
|
||||
const LearningRateSchedule& learningRateSchedule,
|
||||
const MomentumSchedule& momentumSchedule,
|
||||
bool unitGain,
|
||||
const MomentumSchedule& varianceMomentumSchedule,
|
||||
AdditionalLearningOptions additionalOptions);
|
||||
|
||||
|
|
|
@ -1198,8 +1198,11 @@ void CPUMatrix<ElemType>::FSAdagrad(CPUMatrix<ElemType>& gradients,
|
|||
ElemType learnRatePerSample,
|
||||
ElemType momentum,
|
||||
ElemType adaWeight,
|
||||
ElemType adaMul)
|
||||
ElemType adaMul,
|
||||
bool unitGainMomentum)
|
||||
{
|
||||
auto unitGainFactor = ElemType(unitGainMomentum ? (1.0 - momentum) : 1.0);
|
||||
|
||||
size_t numColsNeeded = 2 * gradients.GetNumCols();
|
||||
|
||||
if (IsEmpty() || (GetNumCols() < numColsNeeded))
|
||||
|
@ -1234,7 +1237,7 @@ void CPUMatrix<ElemType>::FSAdagrad(CPUMatrix<ElemType>& gradients,
|
|||
|
||||
if (momentum > 0.0f)
|
||||
{
|
||||
g = momentum * smoothMom[i] + (1.0f - momentum) * g;
|
||||
g = momentum * smoothMom[i] + unitGainFactor * g;
|
||||
smoothMom[i] = g;
|
||||
}
|
||||
|
||||
|
|
|
@ -92,7 +92,10 @@ public:
|
|||
CPUMatrix<ElemType> Diagonal() const;
|
||||
|
||||
ElemType Adagrad(CPUMatrix<ElemType>& gradients, const bool needAveMultiplier);
|
||||
void FSAdagrad(CPUMatrix<ElemType>& gradients, CPUMatrix<ElemType>& functionValues, ElemType learnRatePerSample, ElemType momentum, ElemType adaWeight, ElemType adaMul);
|
||||
|
||||
void FSAdagrad(CPUMatrix<ElemType>& gradients, CPUMatrix<ElemType>& functionValues, ElemType learnRatePerSample,
|
||||
ElemType momentum, ElemType adaWeight, ElemType adaMul, bool unitGainMomentum);
|
||||
|
||||
ElemType RmsProp(CPUMatrix<ElemType>& gradients,
|
||||
ElemType RMS_GAMMA,
|
||||
ElemType RMS_WGT_INC,
|
||||
|
|
|
@ -1126,11 +1126,19 @@ template <class ElemType>
|
|||
return result;
|
||||
}
|
||||
|
||||
// normal update for smoothed gradients c and current gradients (this)
|
||||
// TODO: comment seems wrong; cf. SGD.cpp: smoothedGradient.NormalGrad(gradientValues, functionValues,...)
|
||||
// A helper method used in MomentumSGDUpdate and NesterovAcceleratedMomentumSGDUpdate.
|
||||
// Modifies the smoothed gradients "c", as well as the current gradients "this" on which this method is invoked.
|
||||
// Classic momentum (unitGainFactor == 1.0):
|
||||
// 1) c = momentum * c + this
|
||||
// Unit-gain momentum (unitGainFactor == 1.0 - momentum):
|
||||
// 1) c = momentum * c + (1.0 - momentum) * this
|
||||
// 2) this = c
|
||||
// TODO: NormalGrad is a misnomer here. Come up with a better name.
|
||||
template <class ElemType>
|
||||
void CPUSparseMatrix<ElemType>::NormalGrad(CPUMatrix<ElemType>& c, const ElemType momentum)
|
||||
void CPUSparseMatrix<ElemType>::NormalGrad(CPUMatrix<ElemType>& c, const ElemType momentum, bool unitGainMomentum)
|
||||
{
|
||||
const auto unitGainFactor = ElemType(unitGainMomentum ? (1.0 - momentum) : 1.0);
|
||||
|
||||
if (c.IsEmpty())
|
||||
{
|
||||
c.RequireSize(GetNumRows(), GetNumCols());
|
||||
|
@ -1140,17 +1148,18 @@ void CPUSparseMatrix<ElemType>::NormalGrad(CPUMatrix<ElemType>& c, const ElemTyp
|
|||
|
||||
if (GetFormat() == MatrixFormat::matrixFormatSparseBlockCol || GetFormat() == MatrixFormat::matrixFormatSparseBlockRow)
|
||||
{
|
||||
const auto isSparseBlockCol = (GetFormat() == MatrixFormat::matrixFormatSparseBlockCol);
|
||||
for (size_t j = 0; j < GetBlockSize(); j++)
|
||||
{
|
||||
size_t i = GetBlockIds()[j] - GetBlockIdShift();
|
||||
size_t len = (GetFormat() == MatrixFormat::matrixFormatSparseBlockCol) ? GetNumRows() : GetNumCols();
|
||||
size_t len = (isSparseBlockCol) ? GetNumRows() : GetNumCols();
|
||||
size_t start = j * len;
|
||||
for (size_t p = start; p < start + len; p++)
|
||||
{
|
||||
ElemType val = Buffer()[p];
|
||||
size_t row = (GetFormat() == MatrixFormat::matrixFormatSparseBlockCol) ? (p - start) : i;
|
||||
size_t col = (GetFormat() == MatrixFormat::matrixFormatSparseBlockCol) ? i : (p - start);
|
||||
c(row, col) = (1 - momentum) * val + momentum * c(row, col);
|
||||
size_t row = (isSparseBlockCol) ? (p - start) : i;
|
||||
size_t col = (isSparseBlockCol) ? i : (p - start);
|
||||
c(row, col) = unitGainFactor * val + momentum * c(row, col);
|
||||
Buffer()[p] = c(row, col);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -217,7 +217,7 @@ public:
|
|||
}
|
||||
|
||||
public:
|
||||
void NormalGrad(CPUMatrix<ElemType>& c, const ElemType momentum);
|
||||
void NormalGrad(CPUMatrix<ElemType>& c, const ElemType momentum, bool unitGainMomentum = true);
|
||||
ElemType Adagrad(CPUMatrix<ElemType>& c, const bool needAveMultiplier);
|
||||
|
||||
public:
|
||||
|
|
|
@ -1394,7 +1394,8 @@ void GPUMatrix<ElemType>::FSAdagrad(GPUMatrix<ElemType>& gradients,
|
|||
ElemType learnRatePerSample,
|
||||
ElemType momentum,
|
||||
ElemType adaWeight,
|
||||
ElemType adaMul)
|
||||
ElemType adaMul,
|
||||
bool unitGainMomentum)
|
||||
{
|
||||
size_t numColsNeeded = 2 * gradients.GetNumCols();
|
||||
|
||||
|
@ -1409,7 +1410,7 @@ void GPUMatrix<ElemType>::FSAdagrad(GPUMatrix<ElemType>& gradients,
|
|||
size_t n = gradients.GetNumElements();
|
||||
int blocksPerGrid = (n + GridDim::maxThreadsPerBlock - 1) / GridDim::maxThreadsPerBlock;
|
||||
_fsadagrad<ElemType><<<blocksPerGrid, GridDim::maxThreadsPerBlock>>>(n, gradients.Data(), Data(), Data()+ n, functionValues.Data(),
|
||||
learnRatePerSample, momentum, adaWeight, adaMul);
|
||||
learnRatePerSample, momentum, adaWeight, adaMul, unitGainMomentum);
|
||||
}
|
||||
|
||||
template <class ElemType>
|
||||
|
|
|
@ -224,8 +224,17 @@ public:
|
|||
}
|
||||
|
||||
ElemType Adagrad(GPUMatrix<ElemType>& gradients, const bool needAveMultiplier);
|
||||
void FSAdagrad(GPUMatrix<ElemType>& gradients, GPUMatrix<ElemType>& functionValues, ElemType learnRatePerSample, ElemType momentum, ElemType adaWeight, ElemType adaMul);
|
||||
ElemType RmsProp(GPUMatrix<ElemType>& gradients, ElemType RMS_GAMMA, ElemType RMS_WGT_INC, ElemType RMS_WGT_MAX, ElemType RMS_WGT_DEC, ElemType RMS_WGT_MIN, const bool needAveMultiplier);
|
||||
|
||||
void FSAdagrad(GPUMatrix<ElemType>& gradients, GPUMatrix<ElemType>& functionValues, ElemType learnRatePerSample,
|
||||
ElemType momentum, ElemType adaWeight, ElemType adaMul, bool unitGainMomentum);
|
||||
|
||||
ElemType RmsProp(GPUMatrix<ElemType>& gradients,
|
||||
ElemType RMS_GAMMA,
|
||||
ElemType RMS_WGT_INC,
|
||||
ElemType RMS_WGT_MAX,
|
||||
ElemType RMS_WGT_DEC,
|
||||
ElemType RMS_WGT_MIN,
|
||||
const bool needAveMultiplier);
|
||||
|
||||
void Reshape(const size_t numRows, const size_t numCols);
|
||||
|
||||
|
|
|
@ -1421,8 +1421,9 @@ __global__ void _adagrad4BlockSparse(
|
|||
|
||||
template <class ElemType>
|
||||
__global__ void _fsadagrad(CUDA_LONG size, ElemType* grad, ElemType* smoothAda, ElemType* smoothMom, ElemType* val,
|
||||
ElemType lr, ElemType mom, ElemType adaWeight, ElemType adaMul)
|
||||
ElemType lr, ElemType mom, ElemType adaWeight, ElemType adaMul, bool unitGainMomentum)
|
||||
{
|
||||
const ElemType unitGainFactor = unitGainMomentum ? (1.0 - mom) : 1.0;
|
||||
CUDA_LONG idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
CUDA_LONG stride = blockDim.x * gridDim.x;
|
||||
for (; idx < size; idx += stride)
|
||||
|
@ -1449,7 +1450,7 @@ __global__ void _fsadagrad(CUDA_LONG size, ElemType* grad, ElemType* smoothAda,
|
|||
|
||||
if (mom > 0.0f)
|
||||
{
|
||||
g = mom * smoothMom[idx] + (1.0f - mom) * g;
|
||||
g = mom * smoothMom[idx] + unitGainFactor * g;
|
||||
smoothMom[idx] = g;
|
||||
}
|
||||
|
||||
|
@ -1483,8 +1484,9 @@ template <class ElemType>
|
|||
__global__ void _fsadagrad4BlockSparseCol(CUDA_LONG size,
|
||||
ElemType* grad_bsc, const GPUSPARSE_INDEX_TYPE* colOrRow2blockId, const size_t len,
|
||||
ElemType* smoothAda, ElemType* smoothMom, ElemType* val,
|
||||
ElemType lr, ElemType mom, ElemType adaWeight, ElemType adaMul)
|
||||
ElemType lr, ElemType mom, ElemType adaWeight, ElemType adaMul, bool unitGainMomentum)
|
||||
{
|
||||
const ElemType unitGainFactor = unitGainMomentum ? (1.0 - mom) : 1.0;
|
||||
CUDA_LONG idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
CUDA_LONG stride = blockDim.x * gridDim.x;
|
||||
for (; idx < size; idx += stride)
|
||||
|
@ -1511,7 +1513,7 @@ __global__ void _fsadagrad4BlockSparseCol(CUDA_LONG size,
|
|||
|
||||
if (mom > 0.0f)
|
||||
{
|
||||
g = mom * smoothMom[idx] + (1.0f - mom) * g;
|
||||
g = mom * smoothMom[idx] + unitGainFactor * g;
|
||||
smoothMom[idx] = g;
|
||||
}
|
||||
|
||||
|
@ -3980,8 +3982,10 @@ __global__ void _normalGradForSparseBlock(
|
|||
const size_t numBlocks,
|
||||
ElemType* lhsValues, // lhs is blockCol or blockRow
|
||||
const GPUSPARSE_INDEX_TYPE* blockIds,
|
||||
ElemType* rhs)
|
||||
ElemType* rhs,
|
||||
bool unitGainMomentum)
|
||||
{
|
||||
const ElemType unitGainFactor = unitGainMomentum ? (1.0 - momentum) : 1.0;
|
||||
const CUDA_LONG index = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
CUDA_LONG row, col;
|
||||
if (blockCol)
|
||||
|
@ -4000,7 +4004,7 @@ __global__ void _normalGradForSparseBlock(
|
|||
col = index - numCols * blockId;
|
||||
row = blockIds[blockId];
|
||||
}
|
||||
rhs[IDX2C(row, col, numRows)] = (1 - momentum) * lhsValues[index] + momentum * rhs[IDX2C(row, col, numRows)];
|
||||
rhs[IDX2C(row, col, numRows)] = unitGainFactor * lhsValues[index] + momentum * rhs[IDX2C(row, col, numRows)];
|
||||
lhsValues[index] = rhs[IDX2C(row, col, numRows)];
|
||||
}
|
||||
|
||||
|
|
|
@ -1413,9 +1413,16 @@ GPUSparseMatrix<ElemType>& GPUSparseMatrix<ElemType>::InplaceSoftThreshold(const
|
|||
return *this;
|
||||
}
|
||||
|
||||
// normal update for smoothed gradients c and current gradients (this)
|
||||
// A helper method used in MomentumSGDUpdate and NesterovAcceleratedMomentumSGDUpdate.
|
||||
// Modifies the smoothed gradients "c", as well as the current gradients "this" on which this method is invoked.
|
||||
// Classic momentum (unitGainFactor == 1.0):
|
||||
// 1) c = momentum * c + this
|
||||
// Unit-gain momentum (unitGainFactor == 1.0 - momentum):
|
||||
// 1) c = momentum * c + (1.0 - momentum) * this
|
||||
// 2) this = c
|
||||
// TODO: NormalGrad is a misnomer here. Come up with a better name.
|
||||
template <class ElemType>
|
||||
void GPUSparseMatrix<ElemType>::NormalGrad(GPUMatrix<ElemType>& c, const ElemType momentum)
|
||||
void GPUSparseMatrix<ElemType>::NormalGrad(GPUMatrix<ElemType>& c, const ElemType momentum, bool unitGainMomentum)
|
||||
{
|
||||
VerifyWritable(__FUNCTION__);
|
||||
|
||||
|
@ -1440,7 +1447,8 @@ void GPUSparseMatrix<ElemType>::NormalGrad(GPUMatrix<ElemType>& c, const ElemTyp
|
|||
GetBlockSize(),
|
||||
Data(),
|
||||
BlockId2ColOrRow(),
|
||||
c.Data());
|
||||
c.Data(),
|
||||
unitGainMomentum);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
@ -1512,7 +1520,8 @@ void GPUSparseMatrix<ElemType>::FSAdagrad(
|
|||
ElemType learnRatePerSample,
|
||||
ElemType momentum,
|
||||
ElemType adaWeight,
|
||||
ElemType adaMul)
|
||||
ElemType adaMul,
|
||||
bool unitGainMomentum)
|
||||
{
|
||||
if (GetFormat() != MatrixFormat::matrixFormatSparseBlockCol)
|
||||
{
|
||||
|
@ -1534,7 +1543,7 @@ void GPUSparseMatrix<ElemType>::FSAdagrad(
|
|||
_fsadagrad4BlockSparseCol<ElemType> << <blocksPerGrid, GridDim::maxThreadsPerBlock >> >(
|
||||
n, Data(), ColOrRow2BlockId(), GetNumRows(),
|
||||
c.Data(), c.Data() + n, functionValues.Data(),
|
||||
learnRatePerSample, momentum, adaWeight, adaMul);
|
||||
learnRatePerSample, momentum, adaWeight, adaMul, unitGainMomentum);
|
||||
}
|
||||
|
||||
template <class ElemType>
|
||||
|
|
|
@ -408,9 +408,9 @@ public:
|
|||
const bool transposeB, ElemType beta, GPUMatrix<ElemType>& c, size_t numChannels, size_t horizontalSubsample, bool padding, bool channelwise);
|
||||
static void TensorShuffleScaleAndAdd(ElemType keepWeight, const GPUSparseMatrix<ElemType>& a, size_t D, size_t S, size_t M, size_t K, size_t T, ElemType scaleFactor, const GPUSparseMatrix<ElemType>& b, GPUSparseMatrix<ElemType>& c);
|
||||
|
||||
void NormalGrad(GPUMatrix<ElemType>& c, const ElemType momentum);
|
||||
void NormalGrad(GPUMatrix<ElemType>& c, const ElemType momentum, bool unitGainMomentum = true);
|
||||
ElemType Adagrad(GPUMatrix<ElemType>& c, const bool needAveMultiplier);
|
||||
void FSAdagrad(GPUMatrix<ElemType>& c, GPUMatrix<ElemType>& functionValues, ElemType learnRatePerSample, ElemType momentum, ElemType adaWeight, ElemType adaMul);
|
||||
void FSAdagrad(GPUMatrix<ElemType>& c, GPUMatrix<ElemType>& functionValues, ElemType learnRatePerSample, ElemType momentum, ElemType adaWeight, ElemType adaMul, bool unitGainMomentum);
|
||||
ElemType RmsProp(GPUMatrix<ElemType>& c, ElemType RMS_GAMMA, ElemType RMS_WGT_INC, ElemType RMS_WGT_MAX, ElemType RMS_WGT_DEC, ElemType RMS_WGT_MIN, const bool needAveMultiplier);
|
||||
|
||||
static void Multiply(const GPUSparseMatrix<ElemType>& S, const GPUMatrix<ElemType>& D, GPUMatrix<ElemType>& C);
|
||||
|
|
|
@ -1484,70 +1484,137 @@ void Matrix<ElemType>::SetUniformRandomMask(const ElemType maskRate, const ElemT
|
|||
NOT_IMPLEMENTED);
|
||||
}
|
||||
|
||||
// Vanilla SGD update.
|
||||
// Modifies "this" parameter matrix, on which this method is invoked.
|
||||
template <class ElemType>
|
||||
void Matrix<ElemType>::NormalGrad(Matrix<ElemType>& gradients,
|
||||
Matrix<ElemType>& functionValues,
|
||||
const ElemType learnRatePerSample,
|
||||
const ElemType momentum,
|
||||
const bool useNesterovMomentum)
|
||||
void Matrix<ElemType>::SGDUpdate(Matrix<ElemType>& gradients, ElemType learnRatePerSample)
|
||||
{
|
||||
DecideAndMoveToRightDevice(*this, gradients, functionValues);
|
||||
DecideAndMoveToRightDevice(gradients, *this);
|
||||
|
||||
if (!useNesterovMomentum)
|
||||
DISPATCH_MATRIX_ON_FLAG(&gradients, nullptr,
|
||||
{
|
||||
DISPATCH_MATRIX_ON_FLAG(&gradients, nullptr,
|
||||
{
|
||||
ScaleAndAdd((1 - momentum) * learnRatePerSample, gradients, momentum, *this);
|
||||
functionValues -= *this;
|
||||
},
|
||||
{
|
||||
ScaleAndAdd((1 - momentum) * learnRatePerSample, gradients, momentum, *this);
|
||||
functionValues -= *this;
|
||||
},
|
||||
{
|
||||
if (momentum != 0) gradients.m_CPUSparseMatrix->NormalGrad(*m_CPUMatrix, momentum);
|
||||
ScaleAndAdd(-learnRatePerSample, gradients, functionValues);
|
||||
},
|
||||
{
|
||||
if (momentum != 0) gradients.m_GPUSparseMatrix->NormalGrad(*m_GPUMatrix, momentum);
|
||||
ScaleAndAdd(-learnRatePerSample, gradients, functionValues);
|
||||
});
|
||||
}
|
||||
else
|
||||
// w_t = w_{t-1} - learnRatePerSample * g_{t-1},
|
||||
ScaleAndAdd(ElemType(-learnRatePerSample), gradients, *this);
|
||||
},
|
||||
{
|
||||
DISPATCH_MATRIX_ON_FLAG(&gradients, nullptr,
|
||||
{ /* CPU dense */
|
||||
ScaleAndAdd((1 - momentum) * learnRatePerSample, gradients, momentum, *this);
|
||||
ScaleAndAdd(-momentum, *this, functionValues);
|
||||
ScaleAndAdd(-(1 - momentum) * learnRatePerSample, gradients, functionValues);
|
||||
// w_t = w_{t-1} - momentum * v_ {t-1} - (1-momentum)*learnRatePerSampele*gardient,
|
||||
},
|
||||
{ /* GPU dense */
|
||||
ScaleAndAdd((1 - momentum) * learnRatePerSample, gradients, momentum, *this);
|
||||
ScaleAndAdd(-momentum, *this, functionValues);
|
||||
ScaleAndAdd(-(1 - momentum) * learnRatePerSample, gradients, functionValues);
|
||||
},
|
||||
{ /* CPU sparse */
|
||||
if (momentum != 0)
|
||||
{
|
||||
Matrix<ElemType> gradientCache(gradients.GetDeviceId());
|
||||
gradientCache.AssignValuesOf(gradients);
|
||||
gradients.m_CPUSparseMatrix->NormalGrad(*m_CPUMatrix, momentum);
|
||||
ScaleAndAdd(-momentum, *this, functionValues);
|
||||
ScaleAndAdd(-(1 - momentum) * learnRatePerSample, gradientCache, functionValues);
|
||||
}
|
||||
},
|
||||
{ /* GPU sparse */
|
||||
if (momentum != 0)
|
||||
{
|
||||
Matrix<ElemType> gradientCache(gradients.GetDeviceId());
|
||||
gradientCache.AssignValuesOf(gradients);
|
||||
gradients.m_GPUSparseMatrix->NormalGrad(*m_GPUMatrix, momentum);
|
||||
ScaleAndAdd(-momentum, *this, functionValues);
|
||||
ScaleAndAdd(-(1 - momentum) * learnRatePerSample, gradientCache, functionValues);
|
||||
}
|
||||
});
|
||||
}
|
||||
// BUGBUG: cannot call ScaleAndAdd(ElemType(-learnRatePerSample), gradients, *this) here,
|
||||
// it produces different results from the scale and add below.
|
||||
// g'_{t-1} = learnRatePerSample * g_{t-1}
|
||||
// w_t = w_{t-1} - g'_{t-1}
|
||||
Scale(ElemType(learnRatePerSample), gradients);
|
||||
*this -= gradients;
|
||||
},
|
||||
{
|
||||
ScaleAndAdd(ElemType(-learnRatePerSample), gradients, *this);
|
||||
},
|
||||
{
|
||||
ScaleAndAdd(ElemType(-learnRatePerSample), gradients, *this);
|
||||
});
|
||||
|
||||
}
|
||||
|
||||
// SGD update with momentum.
|
||||
// Modifies "this" parameter matrix, on which this method is invoked.
|
||||
template <class ElemType>
|
||||
void Matrix<ElemType>::MomentumSGDUpdate(Matrix<ElemType>& gradients,
|
||||
Matrix<ElemType>& smoothedGradients,
|
||||
ElemType learnRatePerSample,
|
||||
ElemType momentum,
|
||||
bool unitGainMomentum)
|
||||
{
|
||||
DecideAndMoveToRightDevice(smoothedGradients, gradients, *this);
|
||||
|
||||
const auto unitGainFactor = ElemType(unitGainMomentum ? (1.0 - momentum) : 1.0);
|
||||
|
||||
DISPATCH_MATRIX_ON_FLAG(&gradients, nullptr,
|
||||
{
|
||||
// Classic momentum (unitGainFactor == 1.0):
|
||||
// 1) sg_t = momentum * sg_{t-1} + learnRatePerSample * g_{t-1}
|
||||
// Unit-gain momentum (unitGainFactor == 1.0 - momentum):
|
||||
// 1) sg_t = momentum * sg_{t-1} + learnRatePerSample * (1.0 - momentum) * g_{t-1}
|
||||
// 2) w_t = w_{t-1} - sg_t
|
||||
ScaleAndAdd(unitGainFactor * learnRatePerSample, gradients, momentum, smoothedGradients);
|
||||
*this -= smoothedGradients;
|
||||
},
|
||||
{
|
||||
ScaleAndAdd(unitGainFactor * learnRatePerSample, gradients, momentum, smoothedGradients);
|
||||
*this -= smoothedGradients;
|
||||
},
|
||||
{
|
||||
// The sparse update is slightly different from the dense implementation above:
|
||||
// Classic momentum (unitGainFactor == 1.0):
|
||||
// 1) sg_t = momentum * sg_{t-1} + g_{t-1}
|
||||
// Unit-gain momentum (unitGainFactor == 1.0 - momentum):
|
||||
// 1) sg_t = momentum * sg_{t-1} + (1.0 - momentum) * g_{t-1}
|
||||
// 2) g'_{t-1} = sg_t
|
||||
// 3) w_t = w_{t-1} - learnRatePerSample * g'_{t-1}
|
||||
if (momentum != 0)
|
||||
{
|
||||
gradients.m_CPUSparseMatrix->NormalGrad(*smoothedGradients.m_CPUMatrix, momentum, unitGainMomentum);
|
||||
}
|
||||
ScaleAndAdd(-learnRatePerSample, gradients, *this);
|
||||
},
|
||||
{
|
||||
if (momentum != 0)
|
||||
{
|
||||
gradients.m_GPUSparseMatrix->NormalGrad(*smoothedGradients.m_GPUMatrix, momentum, unitGainMomentum);
|
||||
}
|
||||
ScaleAndAdd(-learnRatePerSample, gradients, *this);
|
||||
});
|
||||
}
|
||||
|
||||
// Nesterov accelerated SGD update.
|
||||
// Modifies "this" parameter matrix, on which this method is invoked.
|
||||
template <class ElemType>
|
||||
void Matrix<ElemType>::NesterovAcceleratedMomentumSGDUpdate(Matrix<ElemType>& gradients,
|
||||
Matrix<ElemType>& smoothedGradients,
|
||||
ElemType learnRatePerSample,
|
||||
ElemType momentum,
|
||||
bool unitGainMomentum)
|
||||
{
|
||||
DecideAndMoveToRightDevice(smoothedGradients, gradients, *this);
|
||||
|
||||
const auto unitGainFactor = ElemType(unitGainMomentum ? (1.0 - momentum) : 1.0);
|
||||
|
||||
DISPATCH_MATRIX_ON_FLAG(&gradients, nullptr,
|
||||
{ /* CPU dense */
|
||||
// 1) sg_t = momentum * sg_{t-1} + learnRatePerSample * unitGainFactor * g_{t-1}
|
||||
// 2) w'_t = w_{t-1} - momentum * sg_t
|
||||
// 3) w_t = w'_t - learnRatePerSample * unitGainFactor * g_{t-1}
|
||||
// The end result:
|
||||
// w_t = w_{t-1} - momentum^2 * sg_{t-1} - learnRatePerSample * unitGainFactor * (1 + momentum) * g_{t-1}
|
||||
// sg_t = momentum * sg_{t-1} + learnRatePerSample * unitGainFactor * g_{t-1}
|
||||
ScaleAndAdd( unitGainFactor * learnRatePerSample, gradients, momentum, smoothedGradients);
|
||||
ScaleAndAdd(-momentum, smoothedGradients, *this);
|
||||
ScaleAndAdd(-unitGainFactor * learnRatePerSample, gradients, *this);
|
||||
},
|
||||
{ /* GPU dense */
|
||||
ScaleAndAdd(unitGainFactor * learnRatePerSample, gradients, momentum, smoothedGradients);
|
||||
ScaleAndAdd(-momentum, smoothedGradients, *this);
|
||||
ScaleAndAdd(-unitGainFactor * learnRatePerSample, gradients, *this);
|
||||
},
|
||||
{ /* CPU sparse */
|
||||
if (momentum != 0)
|
||||
{
|
||||
// Identical to the above, except that as a side effect "NormalGrad" modifies
|
||||
// gradient values in place, so that gradientCache is needed to store the original values.
|
||||
Matrix<ElemType> gradientCache(gradients.GetDeviceId());
|
||||
gradientCache.AssignValuesOf(gradients);
|
||||
gradients.m_CPUSparseMatrix->NormalGrad(*smoothedGradients.m_CPUMatrix, momentum, unitGainMomentum);
|
||||
ScaleAndAdd(-momentum, smoothedGradients, *this);
|
||||
ScaleAndAdd(-unitGainFactor * learnRatePerSample, gradientCache, *this);
|
||||
}
|
||||
},
|
||||
{ /* GPU sparse */
|
||||
if (momentum != 0)
|
||||
{
|
||||
Matrix<ElemType> gradientCache(gradients.GetDeviceId());
|
||||
gradientCache.AssignValuesOf(gradients);
|
||||
gradients.m_GPUSparseMatrix->NormalGrad(*smoothedGradients.m_GPUMatrix, momentum, unitGainMomentum);
|
||||
ScaleAndAdd(-momentum, smoothedGradients, *this);
|
||||
ScaleAndAdd(-unitGainFactor * learnRatePerSample, gradientCache, *this);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// both 'this' and gradients will be changed
|
||||
|
@ -1575,7 +1642,7 @@ template <class ElemType>
|
|||
void Matrix<ElemType>::FSAdagradUpdate(size_t mbSize,
|
||||
Matrix<ElemType>& gradients, Matrix<ElemType>& functionValues, double& smoothedCount,
|
||||
const double learnRatePerSample, const double targetAdagradAvDenom,
|
||||
const double meanMomentum, const double varMomentum)
|
||||
const double meanMomentum, const double varMomentum, bool unitGainMomentum)
|
||||
{
|
||||
// keep track on how many samples have been accumulated into the g^2 accumulator
|
||||
smoothedCount = varMomentum * smoothedCount + (1.0 - varMomentum) * mbSize;
|
||||
|
@ -1587,10 +1654,20 @@ void Matrix<ElemType>::FSAdagradUpdate(size_t mbSize,
|
|||
let targetAdagradAvDenom_x_sqrtAdagradSqrFrames = (ElemType)(targetAdagradAvDenom * sqrt(smoothedCount));
|
||||
|
||||
DISPATCH_MATRIX_ON_FLAG(&gradients, &gradients,
|
||||
{ m_CPUMatrix->FSAdagrad(*gradients.m_CPUMatrix, *functionValues.m_CPUMatrix, (ElemType)learnRatePerSample, (ElemType)meanMomentum, (ElemType)varMomentum, targetAdagradAvDenom_x_sqrtAdagradSqrFrames); SetDataLocation(CPU); },
|
||||
{ m_GPUMatrix->FSAdagrad(*gradients.m_GPUMatrix, *functionValues.m_GPUMatrix, (ElemType)learnRatePerSample, (ElemType)meanMomentum, (ElemType)varMomentum, targetAdagradAvDenom_x_sqrtAdagradSqrFrames); SetDataLocation(GPU); },
|
||||
{
|
||||
m_CPUMatrix->FSAdagrad(*gradients.m_CPUMatrix, *functionValues.m_CPUMatrix,
|
||||
(ElemType)learnRatePerSample, (ElemType)meanMomentum, (ElemType)varMomentum,
|
||||
targetAdagradAvDenom_x_sqrtAdagradSqrFrames, unitGainMomentum);
|
||||
SetDataLocation(CPU);
|
||||
},
|
||||
{
|
||||
m_GPUMatrix->FSAdagrad(*gradients.m_GPUMatrix, *functionValues.m_GPUMatrix,
|
||||
(ElemType)learnRatePerSample, (ElemType)meanMomentum, (ElemType)varMomentum,
|
||||
targetAdagradAvDenom_x_sqrtAdagradSqrFrames, unitGainMomentum);
|
||||
SetDataLocation(GPU);
|
||||
},
|
||||
{ NOT_IMPLEMENTED; },
|
||||
{ gradients.m_GPUSparseMatrix->FSAdagrad(*m_GPUMatrix, *functionValues.m_GPUMatrix, (ElemType)learnRatePerSample, (ElemType)meanMomentum, (ElemType)varMomentum, targetAdagradAvDenom_x_sqrtAdagradSqrFrames); SetDataLocation(GPU); });
|
||||
{ gradients.m_GPUSparseMatrix->FSAdagrad(*m_GPUMatrix, *functionValues.m_GPUMatrix, (ElemType)learnRatePerSample, (ElemType)meanMomentum, (ElemType)varMomentum, targetAdagradAvDenom_x_sqrtAdagradSqrFrames, unitGainMomentum); SetDataLocation(GPU); });
|
||||
|
||||
// Note: Since both 'this' and gradients are changed, we must call SetDataLocation() on 'this' as well.
|
||||
}
|
||||
|
@ -4796,6 +4873,9 @@ template <class ElemType>
|
|||
{
|
||||
ScaleAndAdd(alpha / beta, a, c); // c1=alpha/beta * a + c
|
||||
Scale(beta, c); // c/beta * beta
|
||||
// TODO: two lines above should be changed as follows:
|
||||
// Scale(beta, c); // c1 = c * beta
|
||||
// ScaleAndAdd(alpha, a, c); // c=alpha * a + c1 = alpha * a + beta * c
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -206,13 +206,15 @@ public:
|
|||
Matrix<ElemType> Diagonal() const;
|
||||
void AssignDiagonalValuesTo(Matrix<ElemType>& diag) const;
|
||||
|
||||
// TODO: all these scalars should be passed as doubles and cast down inside
|
||||
void NormalGrad(Matrix<ElemType>& gradients, Matrix<ElemType>& functionValues, const ElemType learnRatePerSample, const ElemType momentum, const bool useNAG);
|
||||
void SGDUpdate(Matrix<ElemType>& gradients, ElemType learnRatePerSample);
|
||||
void MomentumSGDUpdate(Matrix<ElemType>& gradients, Matrix<ElemType>& smoothedGradients, ElemType learnRatePerSample, ElemType momentum, bool unitGainMomentum = true);
|
||||
void NesterovAcceleratedMomentumSGDUpdate(Matrix<ElemType>& gradients, Matrix<ElemType>& smoothedGradients, ElemType learnRatePerSample, ElemType momentum, bool unitGainMomentum = true);
|
||||
|
||||
ElemType Adagrad(Matrix<ElemType>& gradients, const bool needAveMultiplier);
|
||||
void FSAdagradUpdate(size_t mbSize,
|
||||
Matrix<ElemType>& gradients, Matrix<ElemType>& functionValues, double& smoothedCount,
|
||||
const double learnRatePerSample, const double targetAdagradAvDenom,
|
||||
const double meanMomentum, const double varMomentum);
|
||||
const double meanMomentum, const double varMomentum, bool unitGainMomentum = true);
|
||||
ElemType RmsProp(Matrix<ElemType>& gradients, ElemType RMS_GAMMA, ElemType RMS_WGT_INC, ElemType RMS_WGT_MAX, ElemType RMS_WGT_DEC, ElemType RMS_WGT_MIN, const bool needAveMultiplier);
|
||||
|
||||
void Resize(const size_t numRows, const size_t numCols, const size_t numNZElemToReserve = 10000, bool growOnly = true); // by default we only reallocate if need to grow
|
||||
|
|
|
@ -247,7 +247,7 @@ GPUSparseMatrix<ElemType>& GPUSparseMatrix<ElemType>::InplaceTruncate(const Elem
|
|||
|
||||
// normal update for smoothed gradients c and current gradients (this)
|
||||
template <class ElemType>
|
||||
void GPUSparseMatrix<ElemType>::NormalGrad(GPUMatrix<ElemType>& c, const ElemType momentum)
|
||||
void GPUSparseMatrix<ElemType>::NormalGrad(GPUMatrix<ElemType>& c, const ElemType momentum, bool unitGainMomentum)
|
||||
{
|
||||
}
|
||||
template <class ElemType>
|
||||
|
@ -257,7 +257,7 @@ ElemType GPUSparseMatrix<ElemType>::Adagrad(GPUMatrix<ElemType>& c, const bool n
|
|||
}
|
||||
|
||||
template<class ElemType>
|
||||
void GPUSparseMatrix<ElemType>::FSAdagrad(GPUMatrix<ElemType>&, GPUMatrix<ElemType>&, ElemType, ElemType, ElemType, ElemType)
|
||||
void GPUSparseMatrix<ElemType>::FSAdagrad(GPUMatrix<ElemType>&, GPUMatrix<ElemType>&, ElemType, ElemType, ElemType, ElemType, bool)
|
||||
{
|
||||
}
|
||||
|
||||
|
@ -1068,7 +1068,7 @@ ElemType GPUMatrix<ElemType>::Adagrad(GPUMatrix<ElemType>& gradients, const bool
|
|||
}
|
||||
|
||||
template <class ElemType>
|
||||
void GPUMatrix<ElemType>::FSAdagrad(GPUMatrix<ElemType>& gradients, GPUMatrix<ElemType>& functionValues, ElemType learnRatePerSample, ElemType momentum, ElemType adaWeight, ElemType adaMul)
|
||||
void GPUMatrix<ElemType>::FSAdagrad(GPUMatrix<ElemType>& gradients, GPUMatrix<ElemType>& functionValues, ElemType learnRatePerSample, ElemType momentum, ElemType adaWeight, ElemType adaMul, bool unitGainMomentum)
|
||||
{
|
||||
}
|
||||
|
||||
|
|
|
@@ -2149,7 +2149,7 @@ void SGD<ElemType>::InitModelAggregationHandler(int traceLevel, DEVICEID_TYPE de
// UpdateWeights() - actual weight update, implementing various update rules
template <class ElemType>
void SGD<ElemType>::UpdateWeights(Matrix<ElemType>& functionValues, Matrix<ElemType>& gradientValues,
                                  Matrix<ElemType>& smoothedGradient, double& smoothedCount,
                                  Matrix<ElemType>& smoothedGradientValues, double& smoothedCount,
                                  const double learnRatePerSample, const double momentumPerSample,
                                  size_t actualMBSize,
                                  const double L2RegWeight, const double L1RegWeight,

@@ -2164,7 +2164,7 @@ void SGD<ElemType>::UpdateWeights(Matrix<ElemType>& functionValues, Matrix<ElemT
    LOGPRINTF(stderr, "GradUpdateType()=%d, GradientUpdateNoiseStd()=%0.8f\n",
              GradUpdateType(), GradientUpdateNoiseStd());
    gradientValues.Print("Gradient Input");
    smoothedGradient.Print("Smoothed Gradient Input");
    smoothedGradientValues.Print("Smoothed Gradient Input");
#endif

    // make actualMBSize is a valid value

@@ -2194,12 +2194,23 @@ void SGD<ElemType>::UpdateWeights(Matrix<ElemType>& functionValues, Matrix<ElemT
    if (adpType == GradientsUpdateType::None)
    {
        smoothedGradient.NormalGrad(gradientValues, functionValues,
                                    (ElemType) learnRatePerSample, (ElemType) momentum, useNesterovMomentum);
        // even if momentum is 0.0, still need to call a momentum-based update to store
        // [learning rate * current gradient values] in the smoothed gradients, in case
        // the momentum value for the next epoch is non-zero.
        if (!useNesterovMomentum)
        {
            functionValues.MomentumSGDUpdate(gradientValues, smoothedGradientValues,
                                             ElemType(learnRatePerSample), ElemType(momentum));
        }
        else
        {
            functionValues.NesterovAcceleratedMomentumSGDUpdate(gradientValues, smoothedGradientValues,
                                                                ElemType(learnRatePerSample), ElemType(momentum));
        }
    }
    else if (adpType == GradientsUpdateType::AdaGrad)
    {
        double aveMultiplier = smoothedGradient.Adagrad(gradientValues, needAveMultiplier);
        double aveMultiplier = smoothedGradientValues.Adagrad(gradientValues, needAveMultiplier);
        Matrix<ElemType>::ScaleAndAdd((ElemType)(-learnRatePerSample / aveMultiplier), gradientValues, functionValues);
    }
    else if (adpType == GradientsUpdateType::FSAdaGrad)
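The hunk above replaces the old NormalGrad call with an explicit choice between MomentumSGDUpdate and NesterovAcceleratedMomentumSGDUpdate on the renamed smoothedGradientValues buffer. For reference, below is a minimal NumPy sketch of the plain momentum rule, including the unit-gain variant this commit introduces. The formulas are an assumption inferred from the MatrixMomentumSGDUpdate_WithAndWithout_UnitGain test further down (unit gain scales the gradient contribution by 1 - momentum); they are not a copy of the CNTK kernels.

# A sketch of the momentum update assumed to underlie MomentumSGDUpdate, where the
# smoothed gradient already carries the learning rate. Inferred from the unit tests
# below; the real CNTK kernels may differ in details.
import numpy as np

def momentum_sgd_update(params, grad, smoothed_grad, lr, momentum, unit_gain):
    # unit_gain=False: s <- momentum * s + lr * g                  (classic momentum)
    # unit_gain=True:  s <- momentum * s + (1 - momentum) * lr * g (unit-gain low-pass filter)
    gain = (1.0 - momentum) if unit_gain else 1.0
    smoothed_grad[:] = momentum * smoothed_grad + gain * lr * grad
    params -= smoothed_grad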
@@ -2209,14 +2220,14 @@ void SGD<ElemType>::UpdateWeights(Matrix<ElemType>& functionValues, Matrix<ElemT
        static double smoothedCount = 0;
#endif

        smoothedGradient.FSAdagradUpdate(actualMBSize,
        smoothedGradientValues.FSAdagradUpdate(actualMBSize,
                                               gradientValues, functionValues, smoothedCount,
                                               learnRatePerSample, m_gradType.targetAdagradAvDenom,
                                               momentum, varMomentum);
    }
    else if (adpType == GradientsUpdateType::RmsProp)
    {
        double aveMultiplier = smoothedGradient.RmsProp(gradientValues, (ElemType) m_rpi.gamma,
        double aveMultiplier = smoothedGradientValues.RmsProp(gradientValues, (ElemType) m_rpi.gamma,
                                                              (ElemType) m_rpi.inc, (ElemType) m_rpi.max,
                                                              (ElemType) m_rpi.dec, (ElemType) m_rpi.min, needAveMultiplier);
        Matrix<ElemType>::ScaleAndAdd((ElemType)(-learnRatePerSample / aveMultiplier), gradientValues, functionValues);

@@ -2299,8 +2310,8 @@ void SGD<ElemType>::SaveCheckPointInfo(const size_t epoch, const size_t totalSam
    for (auto smoothedGradientIter = smoothedGradients.begin(); smoothedGradientIter != smoothedGradients.end(); smoothedGradientIter++)
    {
        const Matrix<ElemType>& smoothedGradient = *smoothedGradientIter;
        fstream << smoothedGradient;
        const Matrix<ElemType>& smoothedGradientValues = *smoothedGradientIter;
        fstream << smoothedGradientValues;
    }

    fstream.PutMarker(FileMarker::fileMarkerEndSection, L"EGradient");

@@ -2395,8 +2406,8 @@ void SGD<ElemType>::LoadCheckPointInfo(const size_t epochNumber,
    for (auto smoothedGradientIter = smoothedGradients.begin(); smoothedGradientIter != smoothedGradients.end(); smoothedGradientIter++)
    {
        Matrix<ElemType>& smoothedGradient = *smoothedGradientIter;
        fstream >> smoothedGradient;
        Matrix<ElemType>& smoothedGradientValues = *smoothedGradientIter;
        fstream >> smoothedGradientValues;
    }
    fstream.GetMarker(FileMarker::fileMarkerEndSection, L"EGradient");
@@ -394,6 +394,7 @@ class Test:
    ccMajorByCard = {
        'GeForce GTX 780 Ti': 3,
        'GeForce GTX 960': 5,
        'Quadro K2000' : 3,
        'Quadro M2000M': 5,
        'Quadro M4000': 5,
    }

@@ -6410,6 +6410,16 @@ Test module "MathTests" has passed with:
    Test case "MatrixUnitTests/MatrixAssignNumOfDiff" has passed with:
      2 assertions out of 2 passed

    Test case "MatrixUnitTests/MatrixScale" has passed

    Test case "MatrixUnitTests/MatrixSGDUpdate" has passed

    Test case "MatrixUnitTests/MatrixMomentumSGDUpdate_WithAndWithout_UnitGain" has passed

    Test case "MatrixUnitTests/MatrixNesterovAcceleratedMomentumSGDUpdate_WithAndWithout_UnitGain" has passed

    Test case "MatrixUnitTests/MatrixFSAdagradUpdate_WithAndWithout_UnitGain" has passed

  Test suite "QuantizersUnitTests" has passed with:
    2 test cases out of 2 passed
    12 assertions out of 12 passed
@@ -1182,6 +1182,218 @@ BOOST_FIXTURE_TEST_CASE(MatrixAssignNumOfDiff, RandomSeedFixture)
        BOOST_CHECK_EQUAL(expectedDiff, actual.Get00Element());
    }
}

BOOST_FIXTURE_TEST_CASE(MatrixScale, RandomSeedFixture)
{
    const float low = -1.0f;
    const float high = 1.0f;
    float alpha = 0.7713f;
    for (auto deviceId : {CPUDEVICE, c_deviceIdZero})
    {
        auto a1 = SingleMatrix::RandomUniform(7, 11, deviceId, low, high, IncrementCounter());
        auto a2 = a1.DeepClone();
        BOOST_ASSERT(a1.IsEqualTo(a2));

        auto b1 = SingleMatrix::RandomUniform(7, 11, deviceId, low, high, IncrementCounter());
        auto b2 = b1.DeepClone();
        BOOST_ASSERT(b1.IsEqualTo(b2));

        Matrix<float>::ScaleAndAdd(alpha, b1, a1);

        Matrix<float>::Scale(alpha, b2);
        a2 += b2;

        // BUGBUG: this test currently fails on GPU.
        if (deviceId != CPUDEVICE)
            continue;

        BOOST_CHECK(a1.IsEqualTo(a2));
    }
}

BOOST_FIXTURE_TEST_CASE(MatrixSGDUpdate, RandomSeedFixture)
{
    const float low = -1.0f;
    const float high = 1.0f;
    float lr = 0.77f;
    for (auto deviceId : {CPUDEVICE, c_deviceIdZero})
    {
        auto p1 = SingleMatrix::RandomUniform(12, 13, deviceId, low, high, IncrementCounter());
        auto p2 = p1.DeepClone();
        BOOST_ASSERT(p1.IsEqualTo(p2));

        auto g1 = SingleMatrix::RandomUniform(12, 13, deviceId, low, high, IncrementCounter());
        auto g2 = g1.DeepClone();
        BOOST_ASSERT(g1.IsEqualTo(g2));

        auto sg1 = SingleMatrix::RandomUniform(12, 13, deviceId, low, high, IncrementCounter());
        auto sg2 = sg1.DeepClone();
        BOOST_ASSERT(sg1.IsEqualTo(sg2));

        for (; lr > 0.01; lr = lr / 2)
        {
            if (deviceId != CPUDEVICE)
            {
                // g1 is modified inside the GPU version of SGDUpdate, restore the original value here.
                g1.SetValue(g2);
            }

            p1.SGDUpdate(g1, lr);
            p2.MomentumSGDUpdate(g2, sg2, lr, 0.0);

            BOOST_CHECK(p1.IsEqualTo(p2));

            if (deviceId != CPUDEVICE)
                continue;

            // GPU version of SGDUpdate scales gradient by the learning rate, this check will fail.
            BOOST_CHECK(g1.IsEqualTo(g2));
        }

        lr = std::pow(lr, lr);
    }
}

BOOST_FIXTURE_TEST_CASE(MatrixMomentumSGDUpdate_WithAndWithout_UnitGain, RandomSeedFixture)
{
    const float low = -1.0f;
    const float high = 1.0f;
    float lr = 0.77f;
    for (auto deviceId : {CPUDEVICE, c_deviceIdZero})
    {
        auto p1 = SingleMatrix::RandomUniform(12, 13, deviceId, low, high, IncrementCounter());
        auto p2 = p1.DeepClone();
        BOOST_ASSERT(p1.IsEqualTo(p2));

        auto g1 = SingleMatrix::RandomUniform(12, 13, deviceId, low, high, IncrementCounter());
        auto g2 = g1.DeepClone();
        BOOST_ASSERT(g1.IsEqualTo(g2));

        auto sg1 = SingleMatrix::RandomUniform(12, 13, deviceId, low, high, IncrementCounter());
        auto sg2 = sg1.DeepClone();
        BOOST_ASSERT(sg1.IsEqualTo(sg2));

        for (; lr > 0.01; lr = lr / 2)
        {
            p1.MomentumSGDUpdate(g1, sg1, lr, 0.0, true);
            p2.MomentumSGDUpdate(g2, sg2, lr, 0.0, false);
            BOOST_CHECK(p1.IsEqualTo(p2));
        }

        for (lr = 1.0; lr > 0.03; lr = lr / 2)
        {
            p1.MomentumSGDUpdate(g1, sg1, lr, 0.5, true);
            p2.MomentumSGDUpdate(g2, sg2, lr/2, 0.5, false);
            BOOST_CHECK(p1.IsEqualTo(p2));
        }

        BOOST_CHECK(g1.IsEqualTo(g2));
        BOOST_CHECK(sg1.IsEqualTo(sg2));

        p1.MomentumSGDUpdate(g1, sg1, lr, 0.5, true);
        p2.MomentumSGDUpdate(g2, sg2, lr, 0.5, false);
        BOOST_CHECK(!p1.IsEqualTo(p2));

        lr = std::pow(lr, lr);
    }
}
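The second loop above relies on the fact that, with momentum 0.5, the unit-gain update at learning rate lr coincides with the regular update at lr/2: the gradient contribution is (1 - 0.5) * lr = lr/2 in both cases, so the two smoothed-gradient recursions stay identical. The same equivalence can be checked numerically with the assumed update rule sketched earlier (a hypothetical helper, not CNTK code):

# Reproduce the lr vs. lr/2 equivalence asserted by the Boost test, using the
# assumed momentum update rule from the sketch above (not CNTK's kernel).
import numpy as np

def momentum_sgd_update(p, g, s, lr, m, unit_gain):
    s[:] = m * s + ((1.0 - m) if unit_gain else 1.0) * lr * g
    p -= s

rng = np.random.default_rng(0)
p1 = rng.uniform(-1, 1, (12, 13)); p2 = p1.copy()
g  = rng.uniform(-1, 1, (12, 13))
s1 = rng.uniform(-1, 1, (12, 13)); s2 = s1.copy()

lr = 1.0
while lr > 0.03:
    momentum_sgd_update(p1, g, s1, lr,     0.5, unit_gain=True)
    momentum_sgd_update(p2, g, s2, lr / 2, 0.5, unit_gain=False)
    lr /= 2

assert np.allclose(p1, p2) and np.allclose(s1, s2)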
BOOST_FIXTURE_TEST_CASE(MatrixNesterovAcceleratedMomentumSGDUpdate_WithAndWithout_UnitGain, RandomSeedFixture)
{
    const float low = -1.0f;
    const float high = 1.0f;
    float lr = 0.77f;
    for (auto deviceId : {CPUDEVICE, c_deviceIdZero})
    {
        auto p1 = SingleMatrix::RandomUniform(12, 13, deviceId, low, high, IncrementCounter());
        auto p2 = p1.DeepClone();
        BOOST_ASSERT(p1.IsEqualTo(p2));

        auto g1 = SingleMatrix::RandomUniform(12, 13, deviceId, low, high, IncrementCounter());
        auto g2 = g1.DeepClone();
        BOOST_ASSERT(g1.IsEqualTo(g2));

        auto sg1 = SingleMatrix::RandomUniform(12, 13, deviceId, low, high, IncrementCounter());
        auto sg2 = sg1.DeepClone();
        BOOST_ASSERT(sg1.IsEqualTo(sg2));

        for (; lr > 0.01; lr = lr / 2)
        {
            p1.NesterovAcceleratedMomentumSGDUpdate(g1, sg1, lr, 0.0, true);
            p2.NesterovAcceleratedMomentumSGDUpdate(g2, sg2, lr, 0.0, false);
            BOOST_CHECK(p1.IsEqualTo(p2));
        }

        for (lr = 1.0; lr > 0.03; lr = lr / 2)
        {
            p1.NesterovAcceleratedMomentumSGDUpdate(g1, sg1, lr, 0.5, true);
            p2.NesterovAcceleratedMomentumSGDUpdate(g2, sg2, lr/2, 0.5, false);
            BOOST_CHECK(p1.IsEqualTo(p2));
        }

        BOOST_CHECK(g1.IsEqualTo(g2));
        BOOST_CHECK(sg1.IsEqualTo(sg2));

        p1.NesterovAcceleratedMomentumSGDUpdate(g1, sg1, lr, 0.5, true);
        p2.NesterovAcceleratedMomentumSGDUpdate(g2, sg2, lr, 0.5, false);
        BOOST_CHECK(!p1.IsEqualTo(p2));

        lr = std::pow(lr, lr);
    }
}
BOOST_FIXTURE_TEST_CASE(MatrixFSAdagradUpdate_WithAndWithout_UnitGain, RandomSeedFixture)
{
    const float low = -1.0f;
    const float high = 1.0f;
    float lr = 0.77f;
    for (auto deviceId : {CPUDEVICE, c_deviceIdZero})
    {
        auto p1 = SingleMatrix::RandomUniform(12, 13, deviceId, low, high, IncrementCounter());
        auto p2 = p1.DeepClone();
        BOOST_ASSERT(p1.IsEqualTo(p2));

        auto g1 = SingleMatrix::RandomUniform(12, 13, deviceId, low, high, IncrementCounter());
        auto g2 = g1.DeepClone();
        BOOST_ASSERT(g1.IsEqualTo(g2));

        auto sg1 = SingleMatrix::RandomUniform(12, 13, deviceId, low, high, IncrementCounter());
        auto sg2 = sg1.DeepClone();
        BOOST_ASSERT(sg1.IsEqualTo(sg2));

        for (; lr > 0.01; lr = lr / 2)
        {
            size_t mbSize = 100;
            double smoothedCount = 10 / lr;
            double targetAdagradAvDenom = 1.0;
            double varMomentum = 1.0 - lr;

            sg1.FSAdagradUpdate(mbSize, g1, p1, smoothedCount, lr, targetAdagradAvDenom, 0.0, varMomentum, true);
            sg2.FSAdagradUpdate(mbSize, g2, p2, smoothedCount, lr, targetAdagradAvDenom, 0.0, varMomentum, true /*false*/);
            // BUGBUG: at the moment this fails even with identical arguments.
            // BOOST_CHECK(p1.IsEqualTo(p2));
        }

        sg2.SetValue(sg1);
        BOOST_ASSERT(sg1.IsEqualTo(sg2));

        for (lr = 1.0; lr > 0.03; lr = lr / 2)
        {
            size_t mbSize = 100;
            double smoothedCount = 10 / lr;
            double targetAdagradAvDenom = 1.0;
            double varMomentum = 1.0 - lr;

            sg1.FSAdagradUpdate(mbSize, g1, p1, smoothedCount, lr, targetAdagradAvDenom, 0.5, varMomentum, true);
            sg2.FSAdagradUpdate(mbSize, g2, p2, smoothedCount, lr /*lr/2*/, targetAdagradAvDenom, 0.5, varMomentum, true /*false*/);
            // BUGBUG: at the moment this fails even with identical arguments.
            // BOOST_CHECK(p1.IsEqualTo(p2));
        }

        lr = std::pow(lr, lr);
    }
}
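The FSAdaGrad test passes the new flag through to FSAdagradUpdate but keeps its cross-checks commented out (see the BUGBUG notes), so the exact recursion cannot be read off this diff. Under the assumption that the flag affects only the first-moment (momentum) accumulator, in the same way as in the sketches above, while the variance smoothing is untouched, a generic Adam-style illustration would look like this (names and scaling are assumptions, not CNTK's implementation):

# Hypothetical Adam/FSAdaGrad-style update showing where a unit-gain flag could
# plausibly enter: only the first-moment accumulator changes; the second-moment
# (variance) term stays the same. Not CNTK's actual FSAdagradUpdate.
import numpy as np

def adam_like_update(p, g, m_acc, v_acc, lr, momentum, var_momentum, unit_gain, eps=1e-8):
    gain = (1.0 - momentum) if unit_gain else 1.0
    m_acc[:] = momentum * m_acc + gain * g                          # momentum filter
    v_acc[:] = var_momentum * v_acc + (1.0 - var_momentum) * g * g  # variance estimate
    p -= lr * m_acc / (np.sqrt(v_acc) + eps)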
BOOST_AUTO_TEST_SUITE_END()
}
} } }
@@ -56,24 +56,24 @@ void TestSGDLearner(size_t numParameters, size_t numMinibatches, const DeviceDes
}

template <typename ElementType>
void TestMomentumSGDLearner(size_t numParameters, size_t numMinibatches, const DeviceDescriptor& device)
void TestMomentumSGDLearner(size_t numParameters, size_t numMinibatches, bool unitGainMomentum, const DeviceDescriptor& device)
{
    NDShape shape = CreateShape(rng() % maxNumAxes + 1, maxDimSize);
    auto parameters = CreateParameters<ElementType>(shape, numParameters, device);
    LearningRatePerMinibatchSchedule learnigRateSchedule = { { 3.0, 2.0, 1.0 }, numMinibatches };
    MomentumPerSampleSchedule momentumValues = { { { 1, 1.0 }, { 3, 0.1 }, { 10, 0.01 } }, 2 };
    auto learner = MomentumSGDLearner(parameters, learnigRateSchedule, momentumValues);
    auto learner = MomentumSGDLearner(parameters, learnigRateSchedule, momentumValues, unitGainMomentum);
    TestUpdate<ElementType>(learner, shape, numMinibatches, device);
    FloatingPointCompare(learner->LearningRate(), 2.0, "Learner::LearningRate does not match expectation");
}

template <typename ElementType>
void TestNesterovLearner(size_t numParameters, size_t numMinibatches, const DeviceDescriptor& device)
void TestNesterovLearner(size_t numParameters, size_t numMinibatches, bool unitGainMomentum, const DeviceDescriptor& device)
{
    NDShape shape = CreateShape(rng() % maxNumAxes + 1, maxDimSize);
    auto parameters = CreateParameters<ElementType>(shape, numParameters, device);
    MomentumAsTimeConstantSchedule momentumValues = { { { 1, 1 }, { 3, 5 }, { 10, 25 } }, 100 };
    auto learner = NesterovLearner(parameters, LearningRatePerMinibatchSchedule( { { 1, 0.5 }, { 10, 0.25 }, { 20, 0.125 } }, 3 ), momentumValues);
    auto learner = NesterovLearner(parameters, LearningRatePerMinibatchSchedule( { { 1, 0.5 }, { 10, 0.25 }, { 20, 0.125 } }, 3 ), momentumValues, unitGainMomentum);
    TestUpdate<ElementType>(learner, shape, numMinibatches, device);
}

@@ -87,11 +87,11 @@ void TestAdaGradLearner(size_t numParameters, size_t numMinibatches, const Devic
}

template <typename ElementType>
void TestFSAdaGradLearner(size_t numParameters, size_t numMinibatches, const DeviceDescriptor& device)
void TestFSAdaGradLearner(size_t numParameters, size_t numMinibatches, bool unitGainMomentum, const DeviceDescriptor& device)
{
    NDShape shape = CreateShape(rng() % maxNumAxes + 1, maxDimSize);
    auto parameters = CreateParameters<ElementType>(shape, numParameters, device);
    auto learner = AdamLearner(parameters, LearningRatePerSampleSchedule({ 0.5 }), MomentumAsTimeConstantSchedule({ 10.0, 100.0, 1000.0 }));
    auto learner = AdamLearner(parameters, LearningRatePerSampleSchedule({ 0.5 }), MomentumAsTimeConstantSchedule({ 10.0, 100.0, 1000.0 }), unitGainMomentum);
    TestUpdate<ElementType>(learner, shape, numMinibatches, device);
}

@@ -235,21 +235,33 @@ void LearnerTests()
    TestTrainingParametersSchedule();

    TestSGDLearner<double>(5, 3, DeviceDescriptor::CPUDevice());
    TestMomentumSGDLearner<float>(3, 4, DeviceDescriptor::CPUDevice());
    TestNesterovLearner<float>(1, 4, DeviceDescriptor::CPUDevice());
    TestAdaGradLearner<double>(2, 5, DeviceDescriptor::CPUDevice());
    TestFSAdaGradLearner<double>(10, 2, DeviceDescriptor::CPUDevice());
    TestRMSPropLearner<float>(3, 3, DeviceDescriptor::CPUDevice());
    vector<DeviceDescriptor> devices{DeviceDescriptor::CPUDevice()};

    if (IsGPUAvailable())
    {
        TestSGDLearner<double>(1, 1, DeviceDescriptor::CPUDevice());
        TestMomentumSGDLearner<float>(3, 5, DeviceDescriptor::GPUDevice(0));
        TestNesterovLearner<float>(1, 4, DeviceDescriptor::GPUDevice(0));
        TestAdaGradLearner<double>(1, 2, DeviceDescriptor::GPUDevice(0));
        TestFSAdaGradLearner<double>(2, 2, DeviceDescriptor::GPUDevice(0));
        TestRMSPropLearner<float>(3, 3, DeviceDescriptor::GPUDevice(0));
        devices.push_back(DeviceDescriptor::GPUDevice(0));
    }

    srand(1);

    for (auto& device : devices)
    {
        auto numParameters = 1 + rand() % 5;
        auto numMinibatches = 1 + rand() % 5;
        TestSGDLearner<double>(numParameters, numMinibatches, device);
        TestAdaGradLearner<double>(numParameters, numMinibatches, device);
        TestRMSPropLearner<float>(numParameters, numMinibatches, device);
    }

    for (auto& device : devices)
    {
        for (auto unitGain : { true, false })
        {
            auto numParameters = 1 + rand() % 5;
            auto numMinibatches = 1 + rand() % 5;
            TestMomentumSGDLearner<float>(numParameters, numMinibatches, unitGain, device);
            TestNesterovLearner<float>(numParameters, numMinibatches, unitGain, device);
            TestFSAdaGradLearner<double>(numParameters, numMinibatches, unitGain, device);
        }
    }
}
@@ -180,7 +180,7 @@ void TrainSequenceToSequenceTranslator(const DeviceDescriptor& device, bool useS
    AdditionalLearningOptions additionalOptions;
    additionalOptions.gradientClippingThresholdPerSample = 2.3;
    additionalOptions.gradientClippingWithTruncation = true;
    Trainer trainer(z, ce, errs, { MomentumSGDLearner(z->Parameters(), learningRatePerSample, momentumTimeConstant, additionalOptions) });
    Trainer trainer(z, ce, errs, { MomentumSGDLearner(z->Parameters(), learningRatePerSample, momentumTimeConstant, /*unitGainMomentum = */true, additionalOptions) });

    size_t outputFrequencyInMinibatches = 1;
    size_t minibatchSize1 = 72;

@@ -48,7 +48,9 @@ void TrainLSTMSequenceClassifer(const DeviceDescriptor& device, bool useSparseLa
    LearningRatePerSampleSchedule learningRatePerSample = 0.0005;
    MomentumAsTimeConstantSchedule momentumTimeConstant = 256;
    Trainer trainer(classifierOutput, trainingLoss, prediction, { MomentumSGDLearner(classifierOutput->Parameters(), learningRatePerSample, momentumTimeConstant) });
    Trainer trainer(classifierOutput, trainingLoss, prediction,
                    { MomentumSGDLearner(classifierOutput->Parameters(), learningRatePerSample,
                                         momentumTimeConstant, /*unitGainMomentum = */true) });

    size_t outputFrequencyInMinibatches = 1;
    for (size_t i = 0; true; i++)

@@ -416,7 +416,7 @@ Trainer BuildTrainer(const FunctionPtr& function, const Variable& labels,
{
    auto trainingLoss = CNTK::CrossEntropyWithSoftmax(function, labels, L"lossFunction");
    auto prediction = CNTK::ClassificationError(function, labels, L"classificationError");
    auto learner = MomentumSGDLearner(function->Parameters(), lr, m);
    auto learner = MomentumSGDLearner(function->Parameters(), lr, m, /*unitGainMomentum = */true);
    return Trainer(function, trainingLoss, prediction, { learner });
}

@@ -674,7 +674,7 @@ void TestLegacyModelSaving(const DeviceDescriptor& device)
    LearningRatePerSampleSchedule learningRateSchedule2({ { 0.04, 0.02, 0.01, 0.008, 0.004, 0.002, 0.001 } }, actualMBSize);
    MomentumAsTimeConstantSchedule momentumSchedule({ { 900, 800, 700, 600, 500 } }, actualMBSize);
    auto learner2 = AdamLearner(classifierOutput->Parameters(), learningRateSchedule, momentumSchedule);
    auto learner2 = AdamLearner(classifierOutput->Parameters(), learningRateSchedule, momentumSchedule, /*unitGainMomentum = */true);
    Trainer trainer2(classifierOutput, trainingLoss, prediction, { learner });

@@ -111,7 +111,7 @@ void TrainTruncatedLSTMAcousticModelClassifer(const DeviceDescriptor& device, bo
    LearningRatePerSampleSchedule learningRatePerSample = 0.000781;
    MomentumAsTimeConstantSchedule momentumTimeConstant = 6074;
    auto learner = MomentumSGDLearner(classifierOutput->Parameters(), learningRatePerSample, momentumTimeConstant);
    auto learner = MomentumSGDLearner(classifierOutput->Parameters(), learningRatePerSample, momentumTimeConstant, /*unitGainMomentum = */true);
    Trainer trainer(classifierOutput, trainingLoss, prediction, {learner});

    size_t outputFrequencyInMinibatches = 1;
@@ -271,7 +271,7 @@
" # Feel free to try other optimizers from \n",
" # https://www.cntk.ai/pythondocs/cntk.learner.html#module-cntk.learner\n",
" learner = adam_sgd(model.parameters,\n",
" lr=lr_schedule, momentum=momentum_as_time_constant) \n",
" lr=lr_schedule, momentum=momentum_as_time_constant, unit_gain=True) \n",
" \n",
" # Instantiate the trainer\n",
" trainer = Trainer(model, loss, label_error, learner)\n",

@@ -324,6 +324,7 @@
" # trainer object\n",
" learner = momentum_sgd(z.parameters, \n",
" lr = lr_per_minibatch, momentum = momentum_time_constant, \n",
" unit_gain = True, \n",
" l2_regularization_weight=l2_reg_weight)\n",
" trainer = Trainer(z, ce, pe, [learner])\n",
"\n",

@@ -392,6 +392,7 @@
" # https://www.cntk.ai/pythondocs/cntk.learner.html#module-cntk.learner\n",
" learner = adam_sgd(criterion.parameters,\n",
" lr=lr_schedule, momentum=momentum_as_time_constant,\n",
" unit_gain = True, \n",
" low_memory=True,\n",
" gradient_clipping_threshold_per_sample=15, gradient_clipping_with_truncation=True)\n",
"\n",

@@ -497,7 +498,7 @@
" lr_schedule = learning_rate_schedule(1, UnitType.minibatch)\n",
" momentum_as_time_constant = momentum_as_time_constant_schedule(0)\n",
" dummy_learner = adam_sgd(criterion.parameters, \n",
" lr=lr_schedule, momentum=momentum_as_time_constant, low_memory=True)\n",
" lr=lr_schedule, momentum=momentum_as_time_constant, unit_gain=True, low_memory=True)\n",
" evaluator = Trainer(model, criterion.outputs[0], criterion.outputs[1], dummy_learner)\n",
" progress_printer = ProgressPrinter(tag='Evaluation')\n",
"\n",

@@ -753,6 +753,7 @@
"gradient_clipping_with_truncation = True\n",
"learner = momentum_sgd(model.parameters,\n",
" lr_per_sample, momentum_time_constant,\n",
" unit_gain = True, \n",
" gradient_clipping_threshold_per_sample=clipping_threshold_per_sample,\n",
" gradient_clipping_with_truncation=gradient_clipping_with_truncation)\n",
"trainer = Trainer(model, ce, errs, learner)"

@@ -919,6 +920,7 @@
" gradient_clipping_with_truncation = True\n",
" learner = momentum_sgd(model.parameters,\n",
" lr_per_sample, momentum_time_constant,\n",
" unit_gain = True, \n",
" gradient_clipping_threshold_per_sample=clipping_threshold_per_sample,\n",
" gradient_clipping_with_truncation=gradient_clipping_with_truncation)\n",
" trainer = Trainer(model, ce, errs, learner)\n",
@@ -309,15 +309,18 @@
%ignore CNTK::MomentumSGDLearner(const std::vector<Parameter>& parameters,
                                 const LearningRateSchedule& learningRateSchedule,
                                 const MomentumSchedule& momentumSchedule,
                                 bool unitGain,
                                 AdditionalLearningOptions additionalOptions = AdditionalLearningOptions());
%ignore CNTK::NesterovLearner(const std::vector<Parameter>& parameters,
                              const LearningRateSchedule& learningRateSchedule,
                              const MomentumSchedule& momentumSchedule,
                              bool unitGain,
                              AdditionalLearningOptions additionalOptions = AdditionalLearningOptions());
%ignore CNTK::DefaultVarianceMomentum;
%ignore CNTK::AdamLearner(const std::vector<Parameter>& parameters,
                          const LearningRateSchedule& learningRateSchedule,
                          const MomentumSchedule& momentumSchedule,
                          bool unitGain,
                          const MomentumSchedule& varianceMomentumSchedule = DefaultVarianceMomentum,
                          bool lowMemory = true,
                          AdditionalLearningOptions additionalOptions = AdditionalLearningOptions());
@@ -343,7 +343,7 @@ def sgd(parameters, lr,
    return cntk_py.sgd_learner(parameters, lr, additional_options)

@typemap
def momentum_sgd(parameters, lr, momentum,
def momentum_sgd(parameters, lr, momentum, unit_gain,
                 l1_regularization_weight=0.0, l2_regularization_weight=0.0,
                 gaussian_noise_injection_std_dev=0.0, gradient_clipping_threshold_per_sample=np.inf,
                 gradient_clipping_with_truncation=True):

@@ -358,6 +358,7 @@ def momentum_sgd(parameters, lr, momentum,
            :func:`momentum_as_time_constant_schedule`): momentum schedule.
            For additional information, please refer to the `wiki
            <https://github.com/Microsoft/CNTK/wiki/SGD-block#converting-learning-rate-and-momentum-parameters-from-other-toolkits>`_.
        unit_gain: when ``True``, momentum is interpreted as a unit-gain filter.
        l1_regularization_weight (float, optional): the L1 regularization weight per sample,
            defaults to 0.0
        l2_regularization_weight (float, optional): the L2 regularization weight per sample,

@@ -385,11 +386,11 @@ def momentum_sgd(parameters, lr, momentum,
    additional_options.gradient_clipping_threshold_per_sample = gradient_clipping_threshold_per_sample
    additional_options.gradient_clipping_with_truncation = gradient_clipping_with_truncation

    return cntk_py.momentum_sgd_learner(parameters, lr, momentum,
    return cntk_py.momentum_sgd_learner(parameters, lr, momentum, unit_gain,
                                        additional_options)

@typemap
def nesterov(parameters, lr, momentum,
def nesterov(parameters, lr, momentum, unit_gain,
             l1_regularization_weight=0.0, l2_regularization_weight=0.0,
             gaussian_noise_injection_std_dev=0.0, gradient_clipping_threshold_per_sample=np.inf,
             gradient_clipping_with_truncation=True):

@@ -406,6 +407,7 @@ def nesterov(parameters, lr, momentum,
            :func:`momentum_as_time_constant_schedule`): momentum schedule.
            For additional information, please refer to the `wiki
            <https://github.com/Microsoft/CNTK/wiki/SGD-block#converting-learning-rate-and-momentum-parameters-from-other-toolkits>`_.
        unit_gain: when ``True``, momentum is interpreted as a unit-gain filter.
        l1_regularization_weight (float, optional): the L1 regularization weight per sample,
            defaults to 0.0
        l2_regularization_weight (float, optional): the L2 regularization weight per sample,

@@ -442,7 +444,7 @@ def nesterov(parameters, lr, momentum,
    additional_options.gradient_clipping_threshold_per_sample = gradient_clipping_threshold_per_sample
    additional_options.gradient_clipping_with_truncation = gradient_clipping_with_truncation

    return cntk_py.nesterov_learner(parameters, lr, momentum,
    return cntk_py.nesterov_learner(parameters, lr, momentum, unit_gain,
                                    additional_options)

@typemap

@@ -495,7 +497,7 @@ def adagrad(parameters, lr, need_ave_multiplier=True,

# TODO: unCamelCase and integrate upcoming CR
@typemap
def adam_sgd(parameters, lr, momentum,
def adam_sgd(parameters, lr, momentum, unit_gain,
             variance_momentum = momentum_as_time_constant_schedule(720000),
             low_memory=True,
             l1_regularization_weight=0.0, l2_regularization_weight=0.0,

@@ -513,6 +515,7 @@ def adam_sgd(parameters, lr, momentum,
            :func:`momentum_as_time_constant_schedule`): momentum schedule.
            For additional information, please refer to the `wiki
            <https://github.com/Microsoft/CNTK/wiki/SGD-block#converting-learning-rate-and-momentum-parameters-from-other-toolkits>`_.
        unit_gain: when ``True``, momentum is interpreted as a unit-gain filter.
        variance_momentum (output of :func:`momentum_schedule` or
            :func:`momentum_as_time_constant_schedule`): variance momentum schedule. Defaults
            to ``momentum_as_time_constant_schedule(720000)``.

@@ -551,7 +554,7 @@ def adam_sgd(parameters, lr, momentum,
    additional_options.gradient_clipping_threshold_per_sample = gradient_clipping_threshold_per_sample
    additional_options.gradient_clipping_with_truncation = gradient_clipping_with_truncation

    return cntk_py.adam_learner(parameters, lr, momentum,
    return cntk_py.adam_learner(parameters, lr, momentum, unit_gain,
                                variance_momentum, low_memory, additional_options)

@typemap
@@ -45,7 +45,7 @@ def run_distributed_training(tmpdir, create_func):
    momentum_time_constant = momentum_as_time_constant_schedule(1100)
    lr_per_sample = learning_rate_schedule(0.007, UnitType.sample)
    dist_learner = create_func(momentum_sgd(z.parameters, lr_per_sample, momentum_time_constant))
    dist_learner = create_func(momentum_sgd(z.parameters, lr_per_sample, momentum_time_constant, True))

    communicator = dist_learner.communicator()
    workers = communicator.workers()

@@ -73,16 +73,16 @@ def test_learner_init():
    momentum_time_constant = momentum_as_time_constant_schedule(1100)
    lr_per_sample = learning_rate_schedule(0.1, UnitType.sample)
    momentum_sgd(res.parameters, lr_per_sample, momentum_time_constant)
    momentum_sgd(res.parameters, lr_per_sample, momentum_time_constant, True)

    lr_per_sample = learning_rate_schedule([0.1, 0.2], UnitType.sample)
    nesterov(res.parameters, lr=lr_per_sample, momentum=momentum_time_constant)
    nesterov(res.parameters, lr=lr_per_sample, momentum=momentum_time_constant, unit_gain=False)

    lr_per_sample = learning_rate_schedule([0.1]*3 +[0.2]*2 +[0.3], UnitType.sample)
    adagrad(res.parameters, lr=lr_per_sample, need_ave_multiplier=True)

    lr_per_sample = learning_rate_schedule([(3,0.1), (2, 0.2), (1, 0.3)], UnitType.sample)
    adam_sgd(res.parameters, lr=lr_per_sample, momentum=momentum_time_constant)
    adam_sgd(res.parameters, lr=lr_per_sample, momentum=momentum_time_constant, unit_gain=True)

    gamma, inc, dec, max, min = [0.1]*5
    lr_per_sample = learning_rate_schedule([0.1, 0.2], UnitType.sample, 100)

@@ -29,7 +29,7 @@ def test_trainer(tmpdir):
    momentum_time_constant = momentum_as_time_constant_schedule(1100)
    lr_per_sample = learning_rate_schedule(0.007, UnitType.sample)
    trainer = Trainer(z, ce, errs,
                      [momentum_sgd(z.parameters, lr_per_sample, momentum_time_constant)])
                      [momentum_sgd(z.parameters, lr_per_sample, momentum_time_constant, True)])
    in1_value = [[1],[2]]
    label_value = [[0], [1]]
    arguments = {in1: in1_value, labels: label_value}

@@ -57,7 +57,7 @@ def test_output_to_retain():
    momentum_time_constant = momentum_as_time_constant_schedule(1100)
    lr_per_sample = learning_rate_schedule(0.007, UnitType.sample)
    trainer = Trainer(z, ce, errs,
                      [momentum_sgd(z.parameters, lr_per_sample, momentum_time_constant)])
                      [momentum_sgd(z.parameters, lr_per_sample, momentum_time_constant, True)])
    in1_value = [[[1]], [[2]]]
    label_value = [[0], [1]]
    arguments = {in1: in1_value, labels: label_value}