This commit contains a breaking change: it adds a flag to both the
C++ and Python Learner APIs that indicates whether momentum should be
applied in the regular fashion or as a unit-gain filter. This flag is
a required parameter for learners that use momentum (momentum_sgd,
nesterov_sgd and adam_sgd).
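
For illustration, a minimal sketch of how a Python call site changes under the new API (z, lr_schedule and mm_schedule are taken from the examples updated below; since the flag has no default, every call must now pass it explicitly):

    # before this commit
    learner = cntk.learner.momentum_sgd(z.parameters, lr_schedule, mm_schedule)

    # after this commit: unit_gain is required; True keeps the unit-gain behavior
    # used throughout the updated examples, False selects the classic momentum update
    learner = cntk.learner.momentum_sgd(z.parameters, lr_schedule, mm_schedule,
                                        unit_gain=True)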
Alexey Reznichenko 2017-01-11 09:18:42 +01:00
Parent 43b30a1c41
Commit 3ab246855b
43 changed files with 688 additions and 238 deletions

View file

@@ -73,7 +73,9 @@ def convnet_cifar10(debug_output=False):
    l2_reg_weight = 0.002

    # Instantiate the trainer object to drive the model training
-   learner = cntk.learner.momentum_sgd(z.parameters, lr_schedule, mm_schedule, l2_regularization_weight = l2_reg_weight)
+   learner = cntk.learner.momentum_sgd(z.parameters, lr_schedule, mm_schedule,
+                                       unit_gain = True,
+                                       l2_regularization_weight = l2_reg_weight)
    trainer = cntk.Trainer(z, ce, pe, learner)

    # define mapping from reader streams to network inputs

View file

@@ -85,6 +85,7 @@ def convnet_cifar10_dataaug(reader_train, reader_test, max_epochs = 80):
    # trainer object
    learner = cntk.learner.momentum_sgd(z.parameters, lr_schedule, mm_schedule,
+                                       unit_gain = True,
                                        l2_regularization_weight = l2_reg_weight)
    trainer = cntk.Trainer(z, ce, pe, learner)

View file

@@ -93,7 +93,9 @@ def convnet_cifar10_dataaug(create_train_reader, test_reader, create_dist_learne
    # trainer object
    learner = create_dist_learner(
-       cntk.learner.momentum_sgd(z.parameters, lr_schedule, mm_schedule, l2_regularization_weight=l2_reg_weight))
+       cntk.learner.momentum_sgd(z.parameters, lr_schedule, mm_schedule,
+                                 unit_gain = True,
+                                 l2_regularization_weight=l2_reg_weight))
    trainer = cntk.Trainer(z, ce, pe, learner)

View file

@@ -64,7 +64,7 @@ def convnet_mnist(debug_output=False):
    mm_schedule = cntk.learner.momentum_as_time_constant_schedule(mm_time_constant, epoch_size)

    # Instantiate the trainer object to drive the model training
-   learner = cntk.learner.momentum_sgd(z.parameters, lr_schedule, mm_schedule)
+   learner = cntk.learner.momentum_sgd(z.parameters, lr_schedule, mm_schedule, unit_gain=True)
    trainer = cntk.Trainer(z, ce, pe, learner)

    # define mapping from reader streams to network inputs

View file

@@ -88,6 +88,7 @@ def train_and_evaluate(reader_train, reader_test, network_name, max_epochs):
    # trainer object
    learner = momentum_sgd(z.parameters, lr_schedule, mm_schedule,
+                          unit_gain = True,
                           l2_regularization_weight = l2_reg_weight)
    trainer = Trainer(z, ce, pe, learner)

View file

@@ -98,6 +98,7 @@ def train_and_evaluate(create_train_reader, test_reader, network_name, max_epoch
    # trainer object
    learner = create_dist_learner(momentum_sgd(z.parameters, lr_schedule, mm_schedule,
+                                              unit_gain = True,
                                               l2_regularization_weight = l2_reg_weight))
    trainer = Trainer(z, ce, pe, learner)

View file

@@ -83,6 +83,7 @@ def train(reader, model, max_epochs):
    lr_per_sample = learning_rate_schedule(lr_schedule, UnitType.sample, epoch_size)
    learner = adam_sgd(z.parameters,
                       lr=lr_per_sample, momentum=momentum_time_constant,
+                      unit_gain=True,
                       low_memory=True,
                       gradient_clipping_threshold_per_sample=15, gradient_clipping_with_truncation=True)

View file

@@ -157,7 +157,8 @@ def sequence_to_sequence_translator(debug_output=False, run_test=False):
    clipping_threshold_per_sample = 2.3
    gradient_clipping_with_truncation = True
    learner = momentum_sgd(z.parameters,
                           lr_per_minibatch, momentum_time_constant,
+                          unit_gain=True,
                           gradient_clipping_threshold_per_sample=clipping_threshold_per_sample,
                           gradient_clipping_with_truncation=gradient_clipping_with_truncation)
    trainer = Trainer(z, ce, errs, learner)

@@ -241,7 +242,8 @@ def sequence_to_sequence_translator(debug_output=False, run_test=False):
    ce = cross_entropy_with_softmax(z, label_sequence)
    errs = classification_error(z, label_sequence)
    trainer = Trainer(z, ce, errs, [momentum_sgd(
-       z.parameters, lr_per_minibatch, momentum_time_constant, clipping_threshold_per_sample, gradient_clipping_with_truncation)])
+       z.parameters, lr_per_minibatch, momentum_time_constant, True,
+       clipping_threshold_per_sample, gradient_clipping_with_truncation)])

    error2 = translator_test_error(z, trainer, input_vocab_dim, label_vocab_dim)

View file

@@ -164,6 +164,7 @@ def train_lm(training_file):
    clipping_threshold_per_sample = 5.0
    gradient_clipping_with_truncation = True
    learner = momentum_sgd(z.parameters, lr_per_sample, momentum_time_constant,
+                          unit_gain=True,
                           gradient_clipping_threshold_per_sample=clipping_threshold_per_sample,
                           gradient_clipping_with_truncation=gradient_clipping_with_truncation)
    trainer = Trainer(z, ce, errs, learner)

View file

@@ -198,7 +198,7 @@ def conv3d_ucf11(train_reader, test_reader, max_epochs=30):
    mm_schedule = momentum_as_time_constant_schedule(momentum_time_constant, epoch_size=epoch_size)

    # Instantiate the trainer object to drive the model training
-   learner = momentum_sgd(z.parameters, lr_schedule, mm_schedule)
+   learner = momentum_sgd(z.parameters, lr_schedule, mm_schedule, True)
    trainer = Trainer(z, ce, pe, learner)

    log_number_of_parameters(z) ; print()

View file

@@ -3531,6 +3531,7 @@ namespace CNTK
    CNTK_API LearnerPtr MomentumSGDLearner(const std::vector<Parameter>& parameters,
                                           const LearningRateSchedule& learningRateSchedule,
                                           const MomentumSchedule& momentumSchedule,
+                                          bool unitGain,
                                           AdditionalLearningOptions additionalOptions = AdditionalLearningOptions());

    ///

@@ -3539,6 +3540,7 @@ namespace CNTK
    CNTK_API LearnerPtr NesterovLearner(const std::vector<Parameter>& parameters,
                                        const LearningRateSchedule& learningRateSchedule,
                                        const MomentumSchedule& momentumSchedule,
+                                       bool unitGain,
                                        AdditionalLearningOptions additionalOptions = AdditionalLearningOptions());

    static MomentumSchedule DefaultVarianceMomentum = MomentumAsTimeConstantSchedule(2 * 3600 * 100);

@@ -3549,6 +3551,7 @@ namespace CNTK
    CNTK_API LearnerPtr AdamLearner(const std::vector<Parameter>& parameters,
                                    const LearningRateSchedule& learningRateSchedule,
                                    const MomentumSchedule& momentumSchedule,
+                                   bool unitGain,
                                    const MomentumSchedule& varianceMomentumSchedule = DefaultVarianceMomentum,
                                    bool lowMemory = true,
                                    AdditionalLearningOptions additionalOptions = AdditionalLearningOptions());

View file

@@ -9,7 +9,7 @@
#include "Utils.h"
#include "Serialization.h"

-#define UPDATE_FUNCTION \
+#define DISPATCH_TO_TYPED_UPDATE_FUNCTION \
    switch (smoothedGradientValue->GetDataType()) \
    { \
    case DataType::Float: \

@@ -22,6 +22,11 @@
        NOT_IMPLEMENTED; \
    }

+#define GET_WRITABLE_MATRICES \
+    const auto& smoothedGradientMatrix = GetWritableMatrix<ElementType>(smoothedGradientValue); \
+    const auto& gradientMatrix = GetWritableMatrix<ElementType>(gradientValue); \
+    const auto& parameterMatrix = GetWritableMatrix<ElementType>(parameter.Value());
+
using namespace Microsoft::MSR::CNTK;
using namespace std;

@@ -184,15 +189,13 @@ namespace CNTK
            LogicError("Learner parameters contain duplicates.");
        }

-       for (const auto& parameter : parameters)
+       if (allocateSmoothGradients)
        {
-           if (!allocateSmoothGradients)
+           for (const auto& parameter : parameters)
            {
-               continue;
+               NDArrayViewPtr view = AllocateNDArrayView(parameter, parameter.Shape());
+               m_smoothedGradientValues.emplace(parameter, view);
            }
-
-           NDArrayViewPtr view = AllocateNDArrayView(parameter, parameter.Shape());
-           m_smoothedGradientValues.insert(make_pair(parameter, view));
        }
    }

@@ -256,7 +259,7 @@ namespace CNTK
        Print(gradientValue, "Gradient Update");
        Print(smoothedGradientValue, "Smoothed Gradient Input");
#endif
-       UPDATE_FUNCTION;
+       DISPATCH_TO_TYPED_UPDATE_FUNCTION;
#if DUMPOUTPUT
        Print(parameter.Value(), "Parameter Update");

@@ -275,7 +278,8 @@
    }

    template <typename ElementType>
-   void LearnerBase::Update(const Parameter& parameter, const NDArrayViewPtr& gradientValue, const NDArrayViewPtr& smoothedGradientValue, size_t trainingSampleCount) const
+   void LearnerBase::Update(const Parameter& parameter, const NDArrayViewPtr& gradientValue,
+                            const NDArrayViewPtr& smoothedGradientValue, size_t trainingSampleCount) const
    {
        const auto& parameterValue = parameter.Value();
        PreProcess<ElementType>(parameterValue, gradientValue, trainingSampleCount);

@@ -364,27 +368,39 @@ namespace CNTK
        }
    }

-   /*virtual*/ void LearnerSGD::Update(const Parameter& parameter, const NDArrayViewPtr& gradientValue, const NDArrayViewPtr& smoothedGradientValue, size_t trainingSampleCount) const /*override*/
+   LearnerSGD::LearnerSGD(const std::vector<Parameter>& parameters,
+                          const LearningRateSchedule& learningRateSchedule,
+                          AdditionalLearningOptions additionalOptions,
+                          bool allocateSmoothGradients)
+                          : LearnerBase(parameters, learningRateSchedule, additionalOptions, allocateSmoothGradients)
    {
-       UPDATE_FUNCTION;
+       if (!allocateSmoothGradients)
+       {
+           // the vanilla sgd does not need the smooth gradients per se,
+           // insert dummy nd views instead.
+           for (const auto& parameter : parameters)
+           {
+               m_smoothedGradientValues.emplace(parameter, AllocateNDArrayView(parameter, {}));
+           }
+       }
+   }
+
+   /*virtual*/ void LearnerSGD::Update(const Parameter& parameter, const NDArrayViewPtr& gradientValue,
+                                       const NDArrayViewPtr& smoothedGradientValue, size_t trainingSampleCount) const /*override*/
+   {
+       DISPATCH_TO_TYPED_UPDATE_FUNCTION;
    }

    template <typename ElementType>
-   void LearnerSGD::Update(const Parameter& parameter, const NDArrayViewPtr& gradientValue, const NDArrayViewPtr& smoothedGradientValue, size_t trainingSampleCount) const
+   void LearnerSGD::Update(const Parameter& parameter, const NDArrayViewPtr& gradientValue,
+                           const NDArrayViewPtr& smoothedGradientValue, size_t trainingSampleCount) const
    {
-       const auto& parameterValue = parameter.Value();
-       const auto& smoothedGradientMatrix = GetWritableMatrix<ElementType>(smoothedGradientValue);
+       UNUSED(smoothedGradientValue);
+
        const auto& gradientMatrix = GetWritableMatrix<ElementType>(gradientValue);
-       const auto& parameterMatrix = GetWritableMatrix<ElementType>(parameterValue);
+       const auto& parameterMatrix = GetWritableMatrix<ElementType>(parameter.Value());
        const auto learningRate = ElementType(LearningRate(trainingSampleCount));
-       const auto momentum = ElementType(MomentumValueForMB(trainingSampleCount));
-       // TODO: break up the NormalGrad into 3 different functions, each with its own set of parameters
-       // Also, come up with a better name for NormalGrad (Default? Regular? Plain?).
-       // (one for vanilla SGD, the other for momentum SGD, and the third one for NAG).
-       smoothedGradientMatrix->NormalGrad(*gradientMatrix, *parameterMatrix,
-                                          learningRate, momentum, UseNesterovMomentum());
+
+       parameterMatrix->SGDUpdate(*gradientMatrix, learningRate);
    }

    double LearnerMomentumSGD::MomentumValueForMB(const MomentumSchedule& schedule, size_t minibatchSize) const

@@ -397,6 +413,44 @@ namespace CNTK
        return std::pow(currentMomentum, minibatchSize);
    }

+   /*virtual*/ void LearnerMomentumSGD::Update(const Parameter& parameter, const NDArrayViewPtr& gradientValue,
+                                               const NDArrayViewPtr& smoothedGradientValue, size_t trainingSampleCount) const /*override*/
+   {
+       DISPATCH_TO_TYPED_UPDATE_FUNCTION;
+   }
+
+   template <typename ElementType>
+   void LearnerMomentumSGD::Update(const Parameter& parameter, const NDArrayViewPtr& gradientValue,
+                                   const NDArrayViewPtr& smoothedGradientValue, size_t trainingSampleCount) const
+   {
+       GET_WRITABLE_MATRICES;
+
+       const auto learningRate = ElementType(LearningRate(trainingSampleCount));
+       const auto momentum = ElementType(MomentumValueForMB(trainingSampleCount));
+
+       parameterMatrix->MomentumSGDUpdate(*gradientMatrix, *smoothedGradientMatrix,
+                                          learningRate, momentum, UseUnitGainMomentum());
+   }
+
+   /*virtual*/ void LearnerNesterov::Update(const Parameter& parameter, const NDArrayViewPtr& gradientValue,
+                                            const NDArrayViewPtr& smoothedGradientValue, size_t trainingSampleCount) const /*override*/
+   {
+       DISPATCH_TO_TYPED_UPDATE_FUNCTION;
+   }
+
+   template <typename ElementType>
+   void LearnerNesterov::Update(const Parameter& parameter, const NDArrayViewPtr& gradientValue,
+                                const NDArrayViewPtr& smoothedGradientValue, size_t trainingSampleCount) const
+   {
+       GET_WRITABLE_MATRICES;
+
+       const auto learningRate = ElementType(LearningRate(trainingSampleCount));
+       const auto momentum = ElementType(MomentumValueForMB(trainingSampleCount));
+
+       parameterMatrix->NesterovAcceleratedMomentumSGDUpdate(*gradientMatrix, *smoothedGradientMatrix,
+                                                             learningRate, momentum, UseUnitGainMomentum());
+   }
+
    LearnerAdaGrad::LearnerAdaGrad(const std::vector<Parameter>& parameters,
                                   const LearningRateSchedule& learningRateSchedule,
                                   bool needAveMultiplier,

@@ -416,24 +470,21 @@ namespace CNTK
            const auto shape = GetMatrixShape(parameter);
            NDArrayViewPtr view = AllocateNDArrayView(parameter, { shape[0], factor * shape[1] });
-           m_smoothedGradientValues.insert(make_pair(parameter, view));
+           m_smoothedGradientValues.emplace(parameter, view);
        }
    }

-   /*virtual*/ void LearnerAdaGrad::Update(const Parameter& parameter, const NDArrayViewPtr& gradientValue, const NDArrayViewPtr& smoothedGradientValue, size_t trainingSampleCount) const /*override*/
+   /*virtual*/ void LearnerAdaGrad::Update(const Parameter& parameter, const NDArrayViewPtr& gradientValue,
+                                           const NDArrayViewPtr& smoothedGradientValue, size_t trainingSampleCount) const /*override*/
    {
-       UPDATE_FUNCTION;
+       DISPATCH_TO_TYPED_UPDATE_FUNCTION;
    }

    template <typename ElementType>
-   void LearnerAdaGrad::Update(const Parameter& parameter, const NDArrayViewPtr& gradientValue, const NDArrayViewPtr& smoothedGradientValue, size_t trainingSampleCount) const
+   void LearnerAdaGrad::Update(const Parameter& parameter, const NDArrayViewPtr& gradientValue,
+                               const NDArrayViewPtr& smoothedGradientValue, size_t trainingSampleCount) const
    {
-       UNUSED(trainingSampleCount);
-       const auto& parameterValue = parameter.Value();
-       const auto& smoothedGradientMatrix = GetWritableMatrix<ElementType>(smoothedGradientValue);
-       const auto& gradientMatrix = GetWritableMatrix<ElementType>(gradientValue);
-       const auto& parameterMatrix = GetWritableMatrix<ElementType>(parameterValue);
+       GET_WRITABLE_MATRICES

        const auto learningRate = LearningRate(trainingSampleCount);

@@ -446,32 +497,33 @@ namespace CNTK
    LearnerFSAdaGrad::LearnerFSAdaGrad(const vector<Parameter>& parameters,
                                       const LearningRateSchedule& learningRateSchedule,
                                       const MomentumSchedule& momentumSchedule,
+                                      bool unitGain,
                                       const MomentumSchedule& varianceMomentumSchedule,
                                       AdditionalLearningOptions additionalOptions)
-                                      : LearnerMomentumSGD(parameters, learningRateSchedule, momentumSchedule, additionalOptions, /*allocateSmoothGradients*/ false),
+                                      : LearnerMomentumSGD(parameters, learningRateSchedule, momentumSchedule,
+                                                           unitGain, additionalOptions, /*allocateSmoothGradients*/ false),
                                       m_varianceMomentumSchedule(varianceMomentumSchedule)
    {
        for (const auto& parameter : parameters)
        {
            const auto shape = GetMatrixShape(parameter);
            NDArrayViewPtr view = AllocateNDArrayView(parameter, { shape[0], 2 * shape[1] });
-           m_smoothedGradientValues.insert(make_pair(parameter, view));
-           m_smoothedCounts.insert(make_pair(parameter, 0.0));
+           m_smoothedGradientValues.emplace(parameter, view);
+           m_smoothedCounts.emplace(parameter, 0.0);
        }
    }

-   /*virtual*/ void LearnerFSAdaGrad::Update(const Parameter& parameter, const NDArrayViewPtr& gradientValue, const NDArrayViewPtr& smoothedGradientValue, size_t trainingSampleCount) const /*override*/
+   /*virtual*/ void LearnerFSAdaGrad::Update(const Parameter& parameter, const NDArrayViewPtr& gradientValue,
+                                             const NDArrayViewPtr& smoothedGradientValue, size_t trainingSampleCount) const /*override*/
    {
-       UPDATE_FUNCTION;
+       DISPATCH_TO_TYPED_UPDATE_FUNCTION;
    }

    template <typename ElementType>
-   void LearnerFSAdaGrad::Update(const Parameter& parameter, const NDArrayViewPtr& gradientValue, const NDArrayViewPtr& smoothedGradientValue, size_t trainingSampleCount) const
+   void LearnerFSAdaGrad::Update(const Parameter& parameter, const NDArrayViewPtr& gradientValue,
+                                 const NDArrayViewPtr& smoothedGradientValue, size_t trainingSampleCount) const
    {
-       const auto& parameterValue = parameter.Value();
-       const auto& smoothedGradientMatrix = GetWritableMatrix<ElementType>(smoothedGradientValue);
-       const auto& gradientMatrix = GetWritableMatrix<ElementType>(gradientValue);
-       const auto& parameterMatrix = GetWritableMatrix<ElementType>(parameterValue);
+       GET_WRITABLE_MATRICES;

        const auto learningRate = LearningRate(trainingSampleCount);
        const auto momentum = MomentumValueForMB(trainingSampleCount);

@@ -480,7 +532,8 @@ namespace CNTK
        double& smoothedCount = m_smoothedCounts.at(parameter);

-       smoothedGradientMatrix->FSAdagradUpdate(trainingSampleCount, *gradientMatrix, *parameterMatrix, smoothedCount, learningRate, s_targetAdagradAvDenom, momentum, varMomentum);
+       smoothedGradientMatrix->FSAdagradUpdate(trainingSampleCount, *gradientMatrix, *parameterMatrix, smoothedCount, learningRate,
+                                               s_targetAdagradAvDenom, momentum, varMomentum, UseUnitGainMomentum());
    }

    LearnerRMSProp::LearnerRMSProp(const vector<Parameter>& parameters,

@@ -503,24 +556,21 @@ namespace CNTK
            const auto shape = GetMatrixShape(parameter);
            NDArrayViewPtr view = AllocateNDArrayView(parameter, { shape[0], factor * shape[1] });
-           m_smoothedGradientValues.insert(make_pair(parameter, view));
+           m_smoothedGradientValues.emplace(parameter, view);
        }
    }

-   /*virtual*/ void LearnerRMSProp::Update(const Parameter& parameter, const NDArrayViewPtr& gradientValue, const NDArrayViewPtr& smoothedGradientValue, size_t trainingSampleCount) const /*override*/
+   /*virtual*/ void LearnerRMSProp::Update(const Parameter& parameter, const NDArrayViewPtr& gradientValue,
+                                           const NDArrayViewPtr& smoothedGradientValue, size_t trainingSampleCount) const /*override*/
    {
-       UPDATE_FUNCTION;
+       DISPATCH_TO_TYPED_UPDATE_FUNCTION;
    }

    template <typename ElementType>
-   void LearnerRMSProp::Update(const Parameter& parameter, const NDArrayViewPtr& gradientValue, const NDArrayViewPtr& smoothedGradientValue, size_t trainingSampleCount) const
+   void LearnerRMSProp::Update(const Parameter& parameter, const NDArrayViewPtr& gradientValue,
+                               const NDArrayViewPtr& smoothedGradientValue, size_t trainingSampleCount) const
    {
-       UNUSED(trainingSampleCount);
-       const auto& parameterValue = parameter.Value();
-       const auto& smoothedGradientMatrix = GetWritableMatrix<ElementType>(smoothedGradientValue);
-       const auto& gradientMatrix = GetWritableMatrix<ElementType>(gradientValue);
-       const auto& parameterMatrix = GetWritableMatrix<ElementType>(parameterValue);
+       GET_WRITABLE_MATRICES;

        const auto learningRate = LearningRate(trainingSampleCount);

@@ -548,22 +598,25 @@ namespace CNTK
    LearnerPtr MomentumSGDLearner(const vector<Parameter>& parameters,
                                  const LearningRateSchedule& learningRateSchedule,
                                  const MomentumSchedule& momentumSchedule,
+                                 bool unitGain,
                                  AdditionalLearningOptions additionalOptions /*= AdditionalLearningOptions()*/)
    {
-       return MakeSharedObject<LearnerMomentumSGD>(parameters, learningRateSchedule, momentumSchedule, additionalOptions);
+       return MakeSharedObject<LearnerMomentumSGD>(parameters, learningRateSchedule, momentumSchedule, unitGain, additionalOptions);
    }

    LearnerPtr NesterovLearner(const vector<Parameter>& parameters,
                               const LearningRateSchedule& learningRateSchedule,
                               const MomentumSchedule& momentumSchedule,
+                              bool unitGain,
                               AdditionalLearningOptions additionalOptions /*= AdditionalLearningOptions()*/)
    {
-       return MakeSharedObject<LearnerNesterov>(parameters, learningRateSchedule, momentumSchedule, additionalOptions);
+       return MakeSharedObject<LearnerNesterov>(parameters, learningRateSchedule, momentumSchedule, unitGain, additionalOptions);
    }

    LearnerPtr AdamLearner(const vector<Parameter>& parameters,
                           const LearningRateSchedule& learningRateSchedule,
                           const MomentumSchedule& momentumSchedule,
+                          bool unitGain,
                           const MomentumSchedule& varianceMomentumSchedule, /*= MomentumAsTimeConstantSchedulePerSample(2 * 3600 * 100)*/
                           bool lowMemory, /*= true*/
                           AdditionalLearningOptions additionalOptions /*= AdditionalLearningOptions()*/)

@@ -572,7 +625,7 @@
        {
            LogicError("AdamLearner: only the low-memory variant is supported at the moment.");
        }
-       return MakeSharedObject<LearnerFSAdaGrad>(parameters, learningRateSchedule, momentumSchedule, varianceMomentumSchedule, additionalOptions);
+       return MakeSharedObject<LearnerFSAdaGrad>(parameters, learningRateSchedule, momentumSchedule, unitGain, varianceMomentumSchedule, additionalOptions);
    }

    LearnerPtr AdaGradLearner(const vector<Parameter>& parameters,

View file

@@ -108,26 +108,13 @@
    };

    // Vanilla gradient descent optimization algorithm.
-   class LearnerSGD : public LearnerBase
+   class LearnerSGD final : public LearnerBase
    {
    public:
        LearnerSGD(const std::vector<Parameter>& parameters,
                   const LearningRateSchedule& learningRateSchedule,
                   AdditionalLearningOptions additionalOptions,
-                  bool allocateSmoothGradients = true)
-                  : LearnerBase(parameters, learningRateSchedule, additionalOptions, allocateSmoothGradients)
-       {}
-
-       // TODO: get rid of this as soon as NormalGrad is refactored.
-       virtual double MomentumValueForMB(size_t /*minibatchSize*/) const
-       {
-           return 0.0;
-       }
-
-       virtual bool UseNesterovMomentum() const
-       {
-           return false;
-       }
+                  bool allocateSmoothGradients = false);

    protected:

@@ -138,30 +125,45 @@
    };

    // SGD optimization with momentum.
-   class LearnerMomentumSGD : public LearnerSGD
+   class LearnerMomentumSGD : public LearnerBase
    {
    public:
        LearnerMomentumSGD(const std::vector<Parameter>& parameters,
                           const LearningRateSchedule& learningRateSchedule,
                           const MomentumSchedule& momentumSchedule,
+                          bool unitGain,
                           AdditionalLearningOptions additionalOptions,
                           bool allocateSmoothGradients = true)
-                          : LearnerSGD(parameters, learningRateSchedule, additionalOptions, allocateSmoothGradients),
-                          m_momentumSchedule(momentumSchedule)
+                          : LearnerBase(parameters, learningRateSchedule, additionalOptions, allocateSmoothGradients),
+                          m_momentumSchedule(momentumSchedule),
+                          m_unitGain(unitGain)
        { }

        // returns current per-minibatch momentum value.
-       virtual double MomentumValueForMB(size_t minibatchSize) const override
+       virtual double MomentumValueForMB(size_t minibatchSize) const
        {
            return MomentumValueForMB(m_momentumSchedule, minibatchSize);
        }

    protected:
+       virtual void Update(const Parameter& parameter, const NDArrayViewPtr& gradientValue, const NDArrayViewPtr& smoothedGradientValue, size_t trainingSampleCount) const override;
+
+       template <typename ElementType>
+       void Update(const Parameter& parameter, const NDArrayViewPtr& gradientValue, const NDArrayViewPtr& smoothedGradientValue, size_t trainingSampleCount) const;
+
        // returns current per-minibatch momentum value from the provided schedule.
        double MomentumValueForMB(const MomentumSchedule& schedule, size_t minibatchSize) const;
+       // Returns true if the update should use unit-gain momentum and
+       // false if the classic momentum should be used instead.
+       bool UseUnitGainMomentum() const
+       {
+           return m_unitGain;
+       }
+
    private:
        MomentumSchedule m_momentumSchedule;
+       bool m_unitGain;
    };

    // Nesterov's accelerated SGDLearnerBase descent.

@@ -172,14 +174,16 @@
        LearnerNesterov(const std::vector<Parameter>& parameters,
                        const LearningRateSchedule& learningRateSchedule,
                        const MomentumSchedule& momentumSchedule,
+                       bool unitGain,
                        AdditionalLearningOptions additionalOptions)
-                       : LearnerMomentumSGD(parameters, learningRateSchedule, momentumSchedule, additionalOptions, /*allocateSmoothGradients*/ true)
+                       : LearnerMomentumSGD(parameters, learningRateSchedule, momentumSchedule, unitGain, additionalOptions, /*allocateSmoothGradients*/ true)
        {}

-       virtual bool UseNesterovMomentum() const override
-       {
-           return true;
-       }
+   protected:
+       virtual void Update(const Parameter& parameter, const NDArrayViewPtr& gradientValue, const NDArrayViewPtr& smoothedGradientValue, size_t trainingSampleCount) const override;
+
+       template <typename ElementType>
+       void Update(const Parameter& parameter, const NDArrayViewPtr& gradientValue, const NDArrayViewPtr& smoothedGradientValue, size_t trainingSampleCount) const;
    };

    class LearnerAdaGrad : public LearnerBase

@@ -206,6 +210,7 @@
        LearnerFSAdaGrad(const std::vector<Parameter>& parameters,
                         const LearningRateSchedule& learningRateSchedule,
                         const MomentumSchedule& momentumSchedule,
+                        bool unitGain,
                         const MomentumSchedule& varianceMomentumSchedule,
                         AdditionalLearningOptions additionalOptions);

View file

@@ -1198,8 +1198,11 @@ void CPUMatrix<ElemType>::FSAdagrad(CPUMatrix<ElemType>& gradients,
                                    ElemType learnRatePerSample,
                                    ElemType momentum,
                                    ElemType adaWeight,
-                                   ElemType adaMul)
+                                   ElemType adaMul,
+                                   bool unitGainMomentum)
{
+   auto unitGainFactor = ElemType(unitGainMomentum ? (1.0 - momentum) : 1.0);
+
    size_t numColsNeeded = 2 * gradients.GetNumCols();

    if (IsEmpty() || (GetNumCols() < numColsNeeded))

@@ -1234,7 +1237,7 @@
        if (momentum > 0.0f)
        {
-           g = momentum * smoothMom[i] + (1.0f - momentum) * g;
+           g = momentum * smoothMom[i] + unitGainFactor * g;
            smoothMom[i] = g;
        }

View file

@@ -92,7 +92,10 @@
    CPUMatrix<ElemType> Diagonal() const;

    ElemType Adagrad(CPUMatrix<ElemType>& gradients, const bool needAveMultiplier);
-   void FSAdagrad(CPUMatrix<ElemType>& gradients, CPUMatrix<ElemType>& functionValues, ElemType learnRatePerSample, ElemType momentum, ElemType adaWeight, ElemType adaMul);
+   void FSAdagrad(CPUMatrix<ElemType>& gradients, CPUMatrix<ElemType>& functionValues, ElemType learnRatePerSample,
+                  ElemType momentum, ElemType adaWeight, ElemType adaMul, bool unitGainMomentum);
    ElemType RmsProp(CPUMatrix<ElemType>& gradients,
                     ElemType RMS_GAMMA,
                     ElemType RMS_WGT_INC,

View file

@@ -1126,11 +1126,19 @@ template <class ElemType>
    return result;
}

-// normal update for smoothed gradients c and current gradients (this)
-// TODO: comment seems wrong; cf. SGD.cpp: smoothedGradient.NormalGrad(gradientValues, functionValues,...)
+// A helper method used in MomentumSGDUpdate and NesterovAcceleratedMomentumSGDUpdate.
+// Modifies the smoothed gradients "c", as well as the current gradients "this" on which this method is invoked.
+// Classic momentum (unitGainFactor == 1.0):
+// 1) c = momentum * c + this
+// Unit-gain momentum (unitGainFactor == 1.0 - momentum):
+// 1) c = momentum * c + (1.0 - momentum) * this
+// 2) this = c
+// TODO: NormalGrad is a misnomer here. Come up with a better name.
template <class ElemType>
-void CPUSparseMatrix<ElemType>::NormalGrad(CPUMatrix<ElemType>& c, const ElemType momentum)
+void CPUSparseMatrix<ElemType>::NormalGrad(CPUMatrix<ElemType>& c, const ElemType momentum, bool unitGainMomentum)
{
+   const auto unitGainFactor = ElemType(unitGainMomentum ? (1.0 - momentum) : 1.0);
+
    if (c.IsEmpty())
    {
        c.RequireSize(GetNumRows(), GetNumCols());

@@ -1140,17 +1148,18 @@ void CPUSparseMatrix<ElemType>::NormalGrad(CPUMatrix<ElemTyp
    if (GetFormat() == MatrixFormat::matrixFormatSparseBlockCol || GetFormat() == MatrixFormat::matrixFormatSparseBlockRow)
    {
+       const auto isSparseBlockCol = (GetFormat() == MatrixFormat::matrixFormatSparseBlockCol);
        for (size_t j = 0; j < GetBlockSize(); j++)
        {
            size_t i = GetBlockIds()[j] - GetBlockIdShift();
-           size_t len = (GetFormat() == MatrixFormat::matrixFormatSparseBlockCol) ? GetNumRows() : GetNumCols();
+           size_t len = (isSparseBlockCol) ? GetNumRows() : GetNumCols();
            size_t start = j * len;
            for (size_t p = start; p < start + len; p++)
            {
                ElemType val = Buffer()[p];
-               size_t row = (GetFormat() == MatrixFormat::matrixFormatSparseBlockCol) ? (p - start) : i;
-               size_t col = (GetFormat() == MatrixFormat::matrixFormatSparseBlockCol) ? i : (p - start);
-               c(row, col) = (1 - momentum) * val + momentum * c(row, col);
+               size_t row = (isSparseBlockCol) ? (p - start) : i;
+               size_t col = (isSparseBlockCol) ? i : (p - start);
+               c(row, col) = unitGainFactor * val + momentum * c(row, col);
                Buffer()[p] = c(row, col);
            }
        }
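
The two momentum variants spelled out in the NormalGrad comment above can be checked with a small stand-alone sketch (plain Python with scalar stand-ins for the matrices; the names are made up for the illustration and are not part of this commit):

    def momentum_step(c, g, momentum, unit_gain):
        # unit-gain:  c = momentum * c + (1 - momentum) * g
        # classic:    c = momentum * c + g
        unit_gain_factor = (1.0 - momentum) if unit_gain else 1.0
        return momentum * c + unit_gain_factor * g

    c_unit, c_classic = 0.0, 0.0
    for g in [1.0] * 50:
        c_unit = momentum_step(c_unit, g, 0.9, unit_gain=True)
        c_classic = momentum_step(c_classic, g, 0.9, unit_gain=False)
    # For a constant gradient of 1.0, the unit-gain accumulator approaches 1.0,
    # while the classic accumulator approaches 1.0 / (1 - 0.9) = 10.0.
    print(round(c_unit, 3), round(c_classic, 3))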

View file

@@ -217,7 +217,7 @@
    }

public:
-   void NormalGrad(CPUMatrix<ElemType>& c, const ElemType momentum);
+   void NormalGrad(CPUMatrix<ElemType>& c, const ElemType momentum, bool unitGainMomentum = true);
    ElemType Adagrad(CPUMatrix<ElemType>& c, const bool needAveMultiplier);

public:

View file

@@ -1394,7 +1394,8 @@ void GPUMatrix<ElemType>::FSAdagrad(GPUMatrix<ElemType>& gradients,
                                    ElemType learnRatePerSample,
                                    ElemType momentum,
                                    ElemType adaWeight,
-                                   ElemType adaMul)
+                                   ElemType adaMul,
+                                   bool unitGainMomentum)
{
    size_t numColsNeeded = 2 * gradients.GetNumCols();

@@ -1409,7 +1410,7 @@
    size_t n = gradients.GetNumElements();
    int blocksPerGrid = (n + GridDim::maxThreadsPerBlock - 1) / GridDim::maxThreadsPerBlock;
    _fsadagrad<ElemType><<<blocksPerGrid, GridDim::maxThreadsPerBlock>>>(n, gradients.Data(), Data(), Data()+ n, functionValues.Data(),
-                                                                        learnRatePerSample, momentum, adaWeight, adaMul);
+                                                                        learnRatePerSample, momentum, adaWeight, adaMul, unitGainMomentum);
}

template <class ElemType>

View file

@@ -224,8 +224,17 @@
    }

    ElemType Adagrad(GPUMatrix<ElemType>& gradients, const bool needAveMultiplier);
-   void FSAdagrad(GPUMatrix<ElemType>& gradients, GPUMatrix<ElemType>& functionValues, ElemType learnRatePerSample, ElemType momentum, ElemType adaWeight, ElemType adaMul);
-   ElemType RmsProp(GPUMatrix<ElemType>& gradients, ElemType RMS_GAMMA, ElemType RMS_WGT_INC, ElemType RMS_WGT_MAX, ElemType RMS_WGT_DEC, ElemType RMS_WGT_MIN, const bool needAveMultiplier);
+   void FSAdagrad(GPUMatrix<ElemType>& gradients, GPUMatrix<ElemType>& functionValues, ElemType learnRatePerSample,
+                  ElemType momentum, ElemType adaWeight, ElemType adaMul, bool unitGainMomentum);
+   ElemType RmsProp(GPUMatrix<ElemType>& gradients,
+                    ElemType RMS_GAMMA,
+                    ElemType RMS_WGT_INC,
+                    ElemType RMS_WGT_MAX,
+                    ElemType RMS_WGT_DEC,
+                    ElemType RMS_WGT_MIN,
+                    const bool needAveMultiplier);

    void Reshape(const size_t numRows, const size_t numCols);

View file

@@ -1421,8 +1421,9 @@ __global__ void _adagrad4BlockSparse(
template <class ElemType>
__global__ void _fsadagrad(CUDA_LONG size, ElemType* grad, ElemType* smoothAda, ElemType* smoothMom, ElemType* val,
-                          ElemType lr, ElemType mom, ElemType adaWeight, ElemType adaMul)
+                          ElemType lr, ElemType mom, ElemType adaWeight, ElemType adaMul, bool unitGainMomentum)
{
+   const ElemType unitGainFactor = unitGainMomentum ? (1.0 - mom) : 1.0;
    CUDA_LONG idx = blockIdx.x * blockDim.x + threadIdx.x;
    CUDA_LONG stride = blockDim.x * gridDim.x;
    for (; idx < size; idx += stride)

@@ -1449,7 +1450,7 @@
        if (mom > 0.0f)
        {
-           g = mom * smoothMom[idx] + (1.0f - mom) * g;
+           g = mom * smoothMom[idx] + unitGainFactor * g;
            smoothMom[idx] = g;
        }

@@ -1483,8 +1484,9 @@ template <class ElemType>
__global__ void _fsadagrad4BlockSparseCol(CUDA_LONG size,
    ElemType* grad_bsc, const GPUSPARSE_INDEX_TYPE* colOrRow2blockId, const size_t len,
    ElemType* smoothAda, ElemType* smoothMom, ElemType* val,
-   ElemType lr, ElemType mom, ElemType adaWeight, ElemType adaMul)
+   ElemType lr, ElemType mom, ElemType adaWeight, ElemType adaMul, bool unitGainMomentum)
{
+   const ElemType unitGainFactor = unitGainMomentum ? (1.0 - mom) : 1.0;
    CUDA_LONG idx = blockIdx.x * blockDim.x + threadIdx.x;
    CUDA_LONG stride = blockDim.x * gridDim.x;
    for (; idx < size; idx += stride)

@@ -1511,7 +1513,7 @@
        if (mom > 0.0f)
        {
-           g = mom * smoothMom[idx] + (1.0f - mom) * g;
+           g = mom * smoothMom[idx] + unitGainFactor * g;
            smoothMom[idx] = g;
        }

@@ -3980,8 +3982,10 @@ __global__ void _normalGradForSparseBlock(
    const size_t numBlocks,
    ElemType* lhsValues, // lhs is blockCol or blockRow
    const GPUSPARSE_INDEX_TYPE* blockIds,
-   ElemType* rhs)
+   ElemType* rhs,
+   bool unitGainMomentum)
{
+   const ElemType unitGainFactor = unitGainMomentum ? (1.0 - momentum) : 1.0;
    const CUDA_LONG index = blockIdx.x * blockDim.x + threadIdx.x;
    CUDA_LONG row, col;
    if (blockCol)

@@ -4000,7 +4004,7 @@
        col = index - numCols * blockId;
        row = blockIds[blockId];
    }
-   rhs[IDX2C(row, col, numRows)] = (1 - momentum) * lhsValues[index] + momentum * rhs[IDX2C(row, col, numRows)];
+   rhs[IDX2C(row, col, numRows)] = unitGainFactor * lhsValues[index] + momentum * rhs[IDX2C(row, col, numRows)];
    lhsValues[index] = rhs[IDX2C(row, col, numRows)];
}

View file

@@ -1413,9 +1413,16 @@ GPUSparseMatrix<ElemType>& GPUSparseMatrix<ElemType>::InplaceSoftThreshold(const
    return *this;
}

-// normal update for smoothed gradients c and current gradients (this)
+// A helper method used in MomentumSGDUpdate and NesterovAcceleratedMomentumSGDUpdate.
+// Modifies the smoothed gradients "c", as well as the current gradients "this" on which this method is invoked.
+// Classic momentum (unitGainFactor == 1.0):
+// 1) c = momentum * c + this
+// Unit-gain momentum (unitGainFactor == 1.0 - momentum):
+// 1) c = momentum * c + (1.0 - momentum) * this
+// 2) this = c
+// TODO: NormalGrad is a misnomer here. Come up with a better name.
template <class ElemType>
-void GPUSparseMatrix<ElemType>::NormalGrad(GPUMatrix<ElemType>& c, const ElemType momentum)
+void GPUSparseMatrix<ElemType>::NormalGrad(GPUMatrix<ElemType>& c, const ElemType momentum, bool unitGainMomentum)
{
    VerifyWritable(__FUNCTION__);

@@ -1440,7 +1447,8 @@ void GPUSparseMatrix<ElemType>::NormalGrad(GPUMatrix<ElemTyp
            GetBlockSize(),
            Data(),
            BlockId2ColOrRow(),
-           c.Data());
+           c.Data(),
+           unitGainMomentum);
    }
    else
    {

@@ -1512,7 +1520,8 @@ void GPUSparseMatrix<ElemType>::FSAdagrad(
    ElemType learnRatePerSample,
    ElemType momentum,
    ElemType adaWeight,
-   ElemType adaMul)
+   ElemType adaMul,
+   bool unitGainMomentum)
{
    if (GetFormat() != MatrixFormat::matrixFormatSparseBlockCol)
    {

@@ -1534,7 +1543,7 @@
    _fsadagrad4BlockSparseCol<ElemType> << <blocksPerGrid, GridDim::maxThreadsPerBlock >> >(
        n, Data(), ColOrRow2BlockId(), GetNumRows(),
        c.Data(), c.Data() + n, functionValues.Data(),
-       learnRatePerSample, momentum, adaWeight, adaMul);
+       learnRatePerSample, momentum, adaWeight, adaMul, unitGainMomentum);
}

template <class ElemType>

View file

@@ -408,9 +408,9 @@
                       const bool transposeB, ElemType beta, GPUMatrix<ElemType>& c, size_t numChannels, size_t horizontalSubsample, bool padding, bool channelwise);
    static void TensorShuffleScaleAndAdd(ElemType keepWeight, const GPUSparseMatrix<ElemType>& a, size_t D, size_t S, size_t M, size_t K, size_t T, ElemType scaleFactor, const GPUSparseMatrix<ElemType>& b, GPUSparseMatrix<ElemType>& c);

-   void NormalGrad(GPUMatrix<ElemType>& c, const ElemType momentum);
+   void NormalGrad(GPUMatrix<ElemType>& c, const ElemType momentum, bool unitGainMomentum = true);
    ElemType Adagrad(GPUMatrix<ElemType>& c, const bool needAveMultiplier);
-   void FSAdagrad(GPUMatrix<ElemType>& c, GPUMatrix<ElemType>& functionValues, ElemType learnRatePerSample, ElemType momentum, ElemType adaWeight, ElemType adaMul);
+   void FSAdagrad(GPUMatrix<ElemType>& c, GPUMatrix<ElemType>& functionValues, ElemType learnRatePerSample, ElemType momentum, ElemType adaWeight, ElemType adaMul, bool unitGainMomentum);
    ElemType RmsProp(GPUMatrix<ElemType>& c, ElemType RMS_GAMMA, ElemType RMS_WGT_INC, ElemType RMS_WGT_MAX, ElemType RMS_WGT_DEC, ElemType RMS_WGT_MIN, const bool needAveMultiplier);

    static void Multiply(const GPUSparseMatrix<ElemType>& S, const GPUMatrix<ElemType>& D, GPUMatrix<ElemType>& C);

View file

@@ -1484,70 +1484,137 @@ void Matrix<ElemType>::SetUniformRandomMask(const ElemType maskRate, const ElemT
                            NOT_IMPLEMENTED);
}

+// Vanilla SGD update.
+// Modifies "this" parameter matrix, on which this method is invoked.
template <class ElemType>
-void Matrix<ElemType>::NormalGrad(Matrix<ElemType>& gradients,
-                                  Matrix<ElemType>& functionValues,
-                                  const ElemType learnRatePerSample,
-                                  const ElemType momentum,
-                                  const bool useNesterovMomentum)
+void Matrix<ElemType>::SGDUpdate(Matrix<ElemType>& gradients, ElemType learnRatePerSample)
{
-   DecideAndMoveToRightDevice(*this, gradients, functionValues);
+   DecideAndMoveToRightDevice(gradients, *this);

-   if (!useNesterovMomentum)
-   {
-       DISPATCH_MATRIX_ON_FLAG(&gradients, nullptr,
-       {
-           ScaleAndAdd((1 - momentum) * learnRatePerSample, gradients, momentum, *this);
-           functionValues -= *this;
-       },
-       {
-           ScaleAndAdd((1 - momentum) * learnRatePerSample, gradients, momentum, *this);
-           functionValues -= *this;
-       },
-       {
-           if (momentum != 0) gradients.m_CPUSparseMatrix->NormalGrad(*m_CPUMatrix, momentum);
-           ScaleAndAdd(-learnRatePerSample, gradients, functionValues);
-       },
-       {
-           if (momentum != 0) gradients.m_GPUSparseMatrix->NormalGrad(*m_GPUMatrix, momentum);
-           ScaleAndAdd(-learnRatePerSample, gradients, functionValues);
-       });
-   }
-   else
-   {
-       DISPATCH_MATRIX_ON_FLAG(&gradients, nullptr,
-       { /* CPU dense */
-           ScaleAndAdd((1 - momentum) * learnRatePerSample, gradients, momentum, *this);
-           ScaleAndAdd(-momentum, *this, functionValues);
-           ScaleAndAdd(-(1 - momentum) * learnRatePerSample, gradients, functionValues);
-           // w_t = w_{t-1} - momentum * v_ {t-1} - (1-momentum)*learnRatePerSampele*gardient,
-       },
-       { /* GPU dense */
-           ScaleAndAdd((1 - momentum) * learnRatePerSample, gradients, momentum, *this);
-           ScaleAndAdd(-momentum, *this, functionValues);
-           ScaleAndAdd(-(1 - momentum) * learnRatePerSample, gradients, functionValues);
-       },
-       { /* CPU sparse */
-           if (momentum != 0)
-           {
-               Matrix<ElemType> gradientCache(gradients.GetDeviceId());
-               gradientCache.AssignValuesOf(gradients);
-               gradients.m_CPUSparseMatrix->NormalGrad(*m_CPUMatrix, momentum);
-               ScaleAndAdd(-momentum, *this, functionValues);
-               ScaleAndAdd(-(1 - momentum) * learnRatePerSample, gradientCache, functionValues);
-           }
-       },
-       { /* GPU sparse */
-           if (momentum != 0)
-           {
-               Matrix<ElemType> gradientCache(gradients.GetDeviceId());
-               gradientCache.AssignValuesOf(gradients);
-               gradients.m_GPUSparseMatrix->NormalGrad(*m_GPUMatrix, momentum);
-               ScaleAndAdd(-momentum, *this, functionValues);
-               ScaleAndAdd(-(1 - momentum) * learnRatePerSample, gradientCache, functionValues);
-           }
-       });
-   }
+   DISPATCH_MATRIX_ON_FLAG(&gradients, nullptr,
+   {
+       // w_t = w_{t-1} - learnRatePerSample * g_{t-1},
+       ScaleAndAdd(ElemType(-learnRatePerSample), gradients, *this);
+   },
+   {
+       // BUGBUG: cannot call ScaleAndAdd(ElemType(-learnRatePerSample), gradients, *this) here,
+       // it produces different results from the scale and add below.
+       // g'_{t-1} = learnRatePerSample * g_{t-1}
+       // w_t = w_{t-1} - g'_{t-1}
+       Scale(ElemType(learnRatePerSample), gradients);
+       *this -= gradients;
+   },
+   {
+       ScaleAndAdd(ElemType(-learnRatePerSample), gradients, *this);
+   },
+   {
+       ScaleAndAdd(ElemType(-learnRatePerSample), gradients, *this);
+   });
+}
+
+// SGD update with momentum.
+// Modifies "this" parameter matrix, on which this method is invoked.
+template <class ElemType>
+void Matrix<ElemType>::MomentumSGDUpdate(Matrix<ElemType>& gradients,
+                                         Matrix<ElemType>& smoothedGradients,
+                                         ElemType learnRatePerSample,
+                                         ElemType momentum,
+                                         bool unitGainMomentum)
+{
+   DecideAndMoveToRightDevice(smoothedGradients, gradients, *this);
+
+   const auto unitGainFactor = ElemType(unitGainMomentum ? (1.0 - momentum) : 1.0);
+
+   DISPATCH_MATRIX_ON_FLAG(&gradients, nullptr,
+   {
+       // Classic momentum (unitGainFactor == 1.0):
+       // 1) sg_t = momentum * sg_{t-1} + learnRatePerSample * g_{t-1}
+       // Unit-gain momentum (unitGainFactor == 1.0 - momentum):
+       // 1) sg_t = momentum * sg_{t-1} + learnRatePerSample * (1.0 - momentum) * g_{t-1}
+       // 2) w_t = w_{t-1} - sg_t
+       ScaleAndAdd(unitGainFactor * learnRatePerSample, gradients, momentum, smoothedGradients);
+       *this -= smoothedGradients;
+   },
+   {
+       ScaleAndAdd(unitGainFactor * learnRatePerSample, gradients, momentum, smoothedGradients);
+       *this -= smoothedGradients;
+   },
+   {
+       // The sparse update is slightly different from the dense implementation above:
+       // Classic momentum (unitGainFactor == 1.0):
+       // 1) sg_t = momentum * sg_{t-1} + g_{t-1}
+       // Unit-gain momentum (unitGainFactor == 1.0 - momentum):
+       // 1) sg_t = momentum * sg_{t-1} + (1.0 - momentum) * g_{t-1}
+       // 2) g'_{t-1} = sg_t
+       // 3) w_t = w_{t-1} - learnRatePerSample * g'_{t-1}
+       if (momentum != 0)
+       {
+           gradients.m_CPUSparseMatrix->NormalGrad(*smoothedGradients.m_CPUMatrix, momentum, unitGainMomentum);
+       }
+       ScaleAndAdd(-learnRatePerSample, gradients, *this);
+   },
+   {
+       if (momentum != 0)
+       {
+           gradients.m_GPUSparseMatrix->NormalGrad(*smoothedGradients.m_GPUMatrix, momentum, unitGainMomentum);
+       }
+       ScaleAndAdd(-learnRatePerSample, gradients, *this);
+   });
+}
+
+// Nesterov accelerated SGD update.
+// Modifies "this" parameter matrix, on which this method is invoked.
+template <class ElemType>
+void Matrix<ElemType>::NesterovAcceleratedMomentumSGDUpdate(Matrix<ElemType>& gradients,
+                                                            Matrix<ElemType>& smoothedGradients,
+                                                            ElemType learnRatePerSample,
+                                                            ElemType momentum,
+                                                            bool unitGainMomentum)
+{
+   DecideAndMoveToRightDevice(smoothedGradients, gradients, *this);
+
+   const auto unitGainFactor = ElemType(unitGainMomentum ? (1.0 - momentum) : 1.0);
+
+   DISPATCH_MATRIX_ON_FLAG(&gradients, nullptr,
+   { /* CPU dense */
+       // 1) sg_t = momentum * sg_{t-1} + learnRatePerSample * unitGainFactor * g_{t-1}
+       // 2) w'_t = w_{t-1} - momentum * sg_t
+       // 3) w_t = w'_t - learnRatePerSample * unitGainFactor * g_{t-1}
+       // The end result:
+       // w_t = w_{t-1} - momentum^2 * sg_{t-1} - learnRatePerSample * unitGainFactor * (1 + momentum) * g_{t-1}
+       // sg_t = momentum * sg_{t-1} + learnRatePerSample * unitGainFactor * g_{t-1}
+       ScaleAndAdd( unitGainFactor * learnRatePerSample, gradients, momentum, smoothedGradients);
+       ScaleAndAdd(-momentum, smoothedGradients, *this);
+       ScaleAndAdd(-unitGainFactor * learnRatePerSample, gradients, *this);
+   },
+   { /* GPU dense */
+       ScaleAndAdd(unitGainFactor * learnRatePerSample, gradients, momentum, smoothedGradients);
+       ScaleAndAdd(-momentum, smoothedGradients, *this);
+       ScaleAndAdd(-unitGainFactor * learnRatePerSample, gradients, *this);
+   },
+   { /* CPU sparse */
+       if (momentum != 0)
+       {
+           // Identical to the above, except that as a side effect "NormalGrad" modifies
+           // gradient values in place, so that gradientCache is needed to store the original values.
+           Matrix<ElemType> gradientCache(gradients.GetDeviceId());
+           gradientCache.AssignValuesOf(gradients);
+           gradients.m_CPUSparseMatrix->NormalGrad(*smoothedGradients.m_CPUMatrix, momentum, unitGainMomentum);
+           ScaleAndAdd(-momentum, smoothedGradients, *this);
+           ScaleAndAdd(-unitGainFactor * learnRatePerSample, gradientCache, *this);
+       }
+   },
+   { /* GPU sparse */
+       if (momentum != 0)
+       {
+           Matrix<ElemType> gradientCache(gradients.GetDeviceId());
+           gradientCache.AssignValuesOf(gradients);
+           gradients.m_GPUSparseMatrix->NormalGrad(*smoothedGradients.m_GPUMatrix, momentum, unitGainMomentum);
+           ScaleAndAdd(-momentum, smoothedGradients, *this);
+           ScaleAndAdd(-unitGainFactor * learnRatePerSample, gradientCache, *this);
+       }
+   });
}

// both 'this' and gradients will be changed
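
A quick scalar sanity check (illustration only, not part of the commit) that the "end result" formulas in the Nesterov comment above agree with the three ScaleAndAdd steps:

    momentum, lr, ugf = 0.9, 0.1, 1.0   # arbitrary example values
    w, sg, g = 2.0, 0.5, 1.0            # parameter, smoothed gradient, gradient

    sg_t = momentum * sg + lr * ugf * g          # step 1
    w_t  = w - momentum * sg_t - lr * ugf * g    # steps 2 and 3 combined

    w_closed = w - momentum**2 * sg - lr * ugf * (1 + momentum) * g
    assert abs(w_t - w_closed) < 1e-12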
@ -1575,7 +1642,7 @@ template <class ElemType>
void Matrix<ElemType>::FSAdagradUpdate(size_t mbSize, void Matrix<ElemType>::FSAdagradUpdate(size_t mbSize,
Matrix<ElemType>& gradients, Matrix<ElemType>& functionValues, double& smoothedCount, Matrix<ElemType>& gradients, Matrix<ElemType>& functionValues, double& smoothedCount,
const double learnRatePerSample, const double targetAdagradAvDenom, const double learnRatePerSample, const double targetAdagradAvDenom,
const double meanMomentum, const double varMomentum) const double meanMomentum, const double varMomentum, bool unitGainMomentum)
{ {
// keep track on how many samples have been accumulated into the g^2 accumulator // keep track on how many samples have been accumulated into the g^2 accumulator
smoothedCount = varMomentum * smoothedCount + (1.0 - varMomentum) * mbSize; smoothedCount = varMomentum * smoothedCount + (1.0 - varMomentum) * mbSize;
@ -1587,10 +1654,20 @@ void Matrix<ElemType>::FSAdagradUpdate(size_t mbSize,
let targetAdagradAvDenom_x_sqrtAdagradSqrFrames = (ElemType)(targetAdagradAvDenom * sqrt(smoothedCount)); let targetAdagradAvDenom_x_sqrtAdagradSqrFrames = (ElemType)(targetAdagradAvDenom * sqrt(smoothedCount));
DISPATCH_MATRIX_ON_FLAG(&gradients, &gradients, DISPATCH_MATRIX_ON_FLAG(&gradients, &gradients,
{ m_CPUMatrix->FSAdagrad(*gradients.m_CPUMatrix, *functionValues.m_CPUMatrix, (ElemType)learnRatePerSample, (ElemType)meanMomentum, (ElemType)varMomentum, targetAdagradAvDenom_x_sqrtAdagradSqrFrames); SetDataLocation(CPU); }, {
{ m_GPUMatrix->FSAdagrad(*gradients.m_GPUMatrix, *functionValues.m_GPUMatrix, (ElemType)learnRatePerSample, (ElemType)meanMomentum, (ElemType)varMomentum, targetAdagradAvDenom_x_sqrtAdagradSqrFrames); SetDataLocation(GPU); }, m_CPUMatrix->FSAdagrad(*gradients.m_CPUMatrix, *functionValues.m_CPUMatrix,
(ElemType)learnRatePerSample, (ElemType)meanMomentum, (ElemType)varMomentum,
targetAdagradAvDenom_x_sqrtAdagradSqrFrames, unitGainMomentum);
SetDataLocation(CPU);
},
{
m_GPUMatrix->FSAdagrad(*gradients.m_GPUMatrix, *functionValues.m_GPUMatrix,
(ElemType)learnRatePerSample, (ElemType)meanMomentum, (ElemType)varMomentum,
targetAdagradAvDenom_x_sqrtAdagradSqrFrames, unitGainMomentum);
SetDataLocation(GPU);
},
{ NOT_IMPLEMENTED; }, { NOT_IMPLEMENTED; },
{ gradients.m_GPUSparseMatrix->FSAdagrad(*m_GPUMatrix, *functionValues.m_GPUMatrix, (ElemType)learnRatePerSample, (ElemType)meanMomentum, (ElemType)varMomentum, targetAdagradAvDenom_x_sqrtAdagradSqrFrames); SetDataLocation(GPU); }); { gradients.m_GPUSparseMatrix->FSAdagrad(*m_GPUMatrix, *functionValues.m_GPUMatrix, (ElemType)learnRatePerSample, (ElemType)meanMomentum, (ElemType)varMomentum, targetAdagradAvDenom_x_sqrtAdagradSqrFrames, unitGainMomentum); SetDataLocation(GPU); });
// Note: Since both 'this' and gradients are changed, we must call SetDataLocation() on 'this' as well. // Note: Since both 'this' and gradients are changed, we must call SetDataLocation() on 'this' as well.
} }
@ -4796,6 +4873,9 @@ template <class ElemType>
{
    ScaleAndAdd(alpha / beta, a, c); // c1=alpha/beta * a + c
    Scale(beta, c);                  // c/beta * beta
+   // TODO: two lines above should be changed as follows:
+   // Scale(beta, c);           // c1 = c * beta
+   // ScaleAndAdd(alpha, a, c); // c=alpha * a + c1 = alpha * a + beta * c
}
}
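
The unitGainMomentum flag threaded through the updates above changes only how strongly the current gradient is blended into the smoothed gradient. A minimal NumPy sketch of the plain momentum update, assuming (per the SGD.cpp comment further down) that the smoothed gradient stores the learning-rate-scaled step; function and variable names here are illustrative, not the CNTK code:

import numpy as np

def momentum_sgd_update(param, grad, smoothed, lr, momentum, unit_gain):
    # Unit gain blends the new gradient in with weight (1 - momentum), so the
    # smoothed gradient acts as a unit-gain low-pass filter over the gradient stream;
    # without it, the gradient term is added at full strength.
    gain = (1.0 - momentum) if unit_gain else 1.0
    smoothed[:] = momentum * smoothed + gain * lr * grad
    param -= smoothed

p, g, sg = np.zeros(3), np.array([0.5, -1.0, 2.0]), np.zeros(3)
momentum_sgd_update(p, g, sg, lr=0.1, momentum=0.9, unit_gain=True)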


@ -206,13 +206,15 @@ public:
    Matrix<ElemType> Diagonal() const;
    void AssignDiagonalValuesTo(Matrix<ElemType>& diag) const;
-   // TODO: all these scalars should be passed as doubles and cast down inside
-   void NormalGrad(Matrix<ElemType>& gradients, Matrix<ElemType>& functionValues, const ElemType learnRatePerSample, const ElemType momentum, const bool useNAG);
+   void SGDUpdate(Matrix<ElemType>& gradients, ElemType learnRatePerSample);
+   void MomentumSGDUpdate(Matrix<ElemType>& gradients, Matrix<ElemType>& smoothedGradients, ElemType learnRatePerSample, ElemType momentum, bool unitGainMomentum = true);
+   void NesterovAcceleratedMomentumSGDUpdate(Matrix<ElemType>& gradients, Matrix<ElemType>& smoothedGradients, ElemType learnRatePerSample, ElemType momentum, bool unitGainMomentum = true);
    ElemType Adagrad(Matrix<ElemType>& gradients, const bool needAveMultiplier);
    void FSAdagradUpdate(size_t mbSize,
                         Matrix<ElemType>& gradients, Matrix<ElemType>& functionValues, double& smoothedCount,
                         const double learnRatePerSample, const double targetAdagradAvDenom,
-                        const double meanMomentum, const double varMomentum);
+                        const double meanMomentum, const double varMomentum, bool unitGainMomentum = true);
    ElemType RmsProp(Matrix<ElemType>& gradients, ElemType RMS_GAMMA, ElemType RMS_WGT_INC, ElemType RMS_WGT_MAX, ElemType RMS_WGT_DEC, ElemType RMS_WGT_MIN, const bool needAveMultiplier);
    void Resize(const size_t numRows, const size_t numCols, const size_t numNZElemToReserve = 10000, bool growOnly = true); // by default we only reallocate if need to grow


@ -247,7 +247,7 @@ GPUSparseMatrix<ElemType>& GPUSparseMatrix<ElemType>::InplaceTruncate(const Elem
// normal update for smoothed gradients c and current gradients (this)
template <class ElemType>
-void GPUSparseMatrix<ElemType>::NormalGrad(GPUMatrix<ElemType>& c, const ElemType momentum)
+void GPUSparseMatrix<ElemType>::NormalGrad(GPUMatrix<ElemType>& c, const ElemType momentum, bool unitGainMomentum)
{
}
template <class ElemType>
@ -257,7 +257,7 @@ ElemType GPUSparseMatrix<ElemType>::Adagrad(GPUMatrix<ElemType>& c, const bool n
}
template<class ElemType>
-void GPUSparseMatrix<ElemType>::FSAdagrad(GPUMatrix<ElemType>&, GPUMatrix<ElemType>&, ElemType, ElemType, ElemType, ElemType)
+void GPUSparseMatrix<ElemType>::FSAdagrad(GPUMatrix<ElemType>&, GPUMatrix<ElemType>&, ElemType, ElemType, ElemType, ElemType, bool)
{
}
@ -1068,7 +1068,7 @@ ElemType GPUMatrix<ElemType>::Adagrad(GPUMatrix<ElemType>& gradients, const bool
}
template <class ElemType>
-void GPUMatrix<ElemType>::FSAdagrad(GPUMatrix<ElemType>& gradients, GPUMatrix<ElemType>& functionValues, ElemType learnRatePerSample, ElemType momentum, ElemType adaWeight, ElemType adaMul)
+void GPUMatrix<ElemType>::FSAdagrad(GPUMatrix<ElemType>& gradients, GPUMatrix<ElemType>& functionValues, ElemType learnRatePerSample, ElemType momentum, ElemType adaWeight, ElemType adaMul, bool unitGainMomentum)
{
}


@ -2149,7 +2149,7 @@ void SGD<ElemType>::InitModelAggregationHandler(int traceLevel, DEVICEID_TYPE de
// UpdateWeights() - actual weight update, implementing various update rules
template <class ElemType>
void SGD<ElemType>::UpdateWeights(Matrix<ElemType>& functionValues, Matrix<ElemType>& gradientValues,
-                                 Matrix<ElemType>& smoothedGradient, double& smoothedCount,
+                                 Matrix<ElemType>& smoothedGradientValues, double& smoothedCount,
                                  const double learnRatePerSample, const double momentumPerSample,
                                  size_t actualMBSize,
                                  const double L2RegWeight, const double L1RegWeight,
@ -2164,7 +2164,7 @@ void SGD<ElemType>::UpdateWeights(Matrix<ElemType>& functionValues, Matrix<ElemT
    LOGPRINTF(stderr, "GradUpdateType()=%d, GradientUpdateNoiseStd()=%0.8f\n",
              GradUpdateType(), GradientUpdateNoiseStd());
    gradientValues.Print("Gradient Input");
-   smoothedGradient.Print("Smoothed Gradient Input");
+   smoothedGradientValues.Print("Smoothed Gradient Input");
#endif
    // make actualMBSize is a valid value
@ -2194,12 +2194,23 @@ void SGD<ElemType>::UpdateWeights(Matrix<ElemType>& functionValues, Matrix<ElemT
    if (adpType == GradientsUpdateType::None)
    {
-       smoothedGradient.NormalGrad(gradientValues, functionValues,
-                                   (ElemType) learnRatePerSample, (ElemType) momentum, useNesterovMomentum);
+       // even if momentum is 0.0, still need to call a momentum-based update to store
+       // [learning rate * current gradient values] in the smoothed gradients, in case
+       // the momentum value for the next epoch is non-zero.
+       if (!useNesterovMomentum)
+       {
+           functionValues.MomentumSGDUpdate(gradientValues, smoothedGradientValues,
+                                            ElemType(learnRatePerSample), ElemType(momentum));
+       }
+       else
+       {
+           functionValues.NesterovAcceleratedMomentumSGDUpdate(gradientValues, smoothedGradientValues,
+                                                               ElemType(learnRatePerSample), ElemType(momentum));
+       }
    }
    else if (adpType == GradientsUpdateType::AdaGrad)
    {
-       double aveMultiplier = smoothedGradient.Adagrad(gradientValues, needAveMultiplier);
+       double aveMultiplier = smoothedGradientValues.Adagrad(gradientValues, needAveMultiplier);
        Matrix<ElemType>::ScaleAndAdd((ElemType)(-learnRatePerSample / aveMultiplier), gradientValues, functionValues);
    }
    else if (adpType == GradientsUpdateType::FSAdaGrad)
@ -2209,14 +2220,14 @@ void SGD<ElemType>::UpdateWeights(Matrix<ElemType>& functionValues, Matrix<ElemT
        static double smoothedCount = 0;
#endif
-       smoothedGradient.FSAdagradUpdate(actualMBSize,
+       smoothedGradientValues.FSAdagradUpdate(actualMBSize,
                                         gradientValues, functionValues, smoothedCount,
                                         learnRatePerSample, m_gradType.targetAdagradAvDenom,
                                         momentum, varMomentum);
    }
    else if (adpType == GradientsUpdateType::RmsProp)
    {
-       double aveMultiplier = smoothedGradient.RmsProp(gradientValues, (ElemType) m_rpi.gamma,
+       double aveMultiplier = smoothedGradientValues.RmsProp(gradientValues, (ElemType) m_rpi.gamma,
                                                        (ElemType) m_rpi.inc, (ElemType) m_rpi.max,
                                                        (ElemType) m_rpi.dec, (ElemType) m_rpi.min, needAveMultiplier);
        Matrix<ElemType>::ScaleAndAdd((ElemType)(-learnRatePerSample / aveMultiplier), gradientValues, functionValues);
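
For the GradientsUpdateType::None path above, the Nesterov branch differs from plain momentum only in how the look-ahead step is applied to the parameters. A rough sketch under the same assumptions as the earlier momentum example; the exact factoring in NesterovAcceleratedMomentumSGDUpdate may differ from this:

import numpy as np

def nesterov_momentum_update(param, grad, smoothed, lr, momentum, unit_gain):
    gain = (1.0 - momentum) if unit_gain else 1.0
    # update the smoothed gradient exactly as in the plain momentum case ...
    smoothed[:] = momentum * smoothed + gain * lr * grad
    # ... but step along the "looked-ahead" direction: the momentum-scaled smoothed
    # gradient plus the current (gain-scaled) gradient step
    param -= momentum * smoothed + gain * lr * grad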
@ -2299,8 +2310,8 @@ void SGD<ElemType>::SaveCheckPointInfo(const size_t epoch, const size_t totalSam
    for (auto smoothedGradientIter = smoothedGradients.begin(); smoothedGradientIter != smoothedGradients.end(); smoothedGradientIter++)
    {
-       const Matrix<ElemType>& smoothedGradient = *smoothedGradientIter;
-       fstream << smoothedGradient;
+       const Matrix<ElemType>& smoothedGradientValues = *smoothedGradientIter;
+       fstream << smoothedGradientValues;
    }
    fstream.PutMarker(FileMarker::fileMarkerEndSection, L"EGradient");
@ -2395,8 +2406,8 @@ void SGD<ElemType>::LoadCheckPointInfo(const size_t epochNumber,
    for (auto smoothedGradientIter = smoothedGradients.begin(); smoothedGradientIter != smoothedGradients.end(); smoothedGradientIter++)
    {
-       Matrix<ElemType>& smoothedGradient = *smoothedGradientIter;
-       fstream >> smoothedGradient;
+       Matrix<ElemType>& smoothedGradientValues = *smoothedGradientIter;
+       fstream >> smoothedGradientValues;
    }
    fstream.GetMarker(FileMarker::fileMarkerEndSection, L"EGradient");


@ -394,6 +394,7 @@ class Test:
    ccMajorByCard = {
        'GeForce GTX 780 Ti': 3,
        'GeForce GTX 960': 5,
+       'Quadro K2000' : 3,
        'Quadro M2000M': 5,
        'Quadro M4000': 5,
    }


@ -6409,7 +6409,17 @@ Test module "MathTests" has passed with:
Test case "MatrixUnitTests/MatrixAssignNumOfDiff" has passed with: Test case "MatrixUnitTests/MatrixAssignNumOfDiff" has passed with:
2 assertions out of 2 passed 2 assertions out of 2 passed
Test case "MatrixUnitTests/MatrixScale" has passed
Test case "MatrixUnitTests/MatrixSGDUpdate" has passed
Test case "MatrixUnitTests/MatrixMomentumSGDUpdate_WithAndWithout_UnitGain" has passed
Test case "MatrixUnitTests/MatrixNesterovAcceleratedMomentumSGDUpdate_WithAndWithout_UnitGain" has passed
Test case "MatrixUnitTests/MatrixFSAdagradUpdate_WithAndWithout_UnitGain" has passed
Test suite "QuantizersUnitTests" has passed with: Test suite "QuantizersUnitTests" has passed with:
2 test cases out of 2 passed 2 test cases out of 2 passed
12 assertions out of 12 passed 12 assertions out of 12 passed


@ -1182,6 +1182,218 @@ BOOST_FIXTURE_TEST_CASE(MatrixAssignNumOfDiff, RandomSeedFixture)
        BOOST_CHECK_EQUAL(expectedDiff, actual.Get00Element());
    }
}
BOOST_FIXTURE_TEST_CASE(MatrixScale, RandomSeedFixture)
{
const float low = -1.0f;
const float high = 1.0f;
float alpha = 0.7713f;
for (auto deviceId : {CPUDEVICE, c_deviceIdZero})
{
auto a1 = SingleMatrix::RandomUniform(7, 11, deviceId, low, high, IncrementCounter());
auto a2 = a1.DeepClone();
BOOST_ASSERT(a1.IsEqualTo(a2));
auto b1 = SingleMatrix::RandomUniform(7, 11, deviceId, low, high, IncrementCounter());
auto b2 = b1.DeepClone();
BOOST_ASSERT(b1.IsEqualTo(b2));
Matrix<float>::ScaleAndAdd(alpha, b1, a1);
Matrix<float>::Scale(alpha, b2);
a2 += b2;
// BUGBUG: this test currently fails on GPU.
if (deviceId != CPUDEVICE)
continue;
BOOST_CHECK(a1.IsEqualTo(a2));
}
}
BOOST_FIXTURE_TEST_CASE(MatrixSGDUpdate, RandomSeedFixture)
{
const float low = -1.0f;
const float high = 1.0f;
float lr = 0.77f;
for (auto deviceId : {CPUDEVICE, c_deviceIdZero})
{
auto p1 = SingleMatrix::RandomUniform(12, 13, deviceId, low, high, IncrementCounter());
auto p2 = p1.DeepClone();
BOOST_ASSERT(p1.IsEqualTo(p2));
auto g1 = SingleMatrix::RandomUniform(12, 13, deviceId, low, high, IncrementCounter());
auto g2 = g1.DeepClone();
BOOST_ASSERT(g1.IsEqualTo(g2));
auto sg1 = SingleMatrix::RandomUniform(12, 13, deviceId, low, high, IncrementCounter());
auto sg2 = sg1.DeepClone();
BOOST_ASSERT(sg1.IsEqualTo(sg2));
for (; lr > 0.01; lr = lr / 2)
{
if (deviceId != CPUDEVICE)
{
// g1 is modified inside the GPU version of SGDUpdate, restore the original value here.
g1.SetValue(g2);
}
p1.SGDUpdate(g1, lr);
p2.MomentumSGDUpdate(g2, sg2, lr, 0.0);
BOOST_CHECK(p1.IsEqualTo(p2));
if (deviceId != CPUDEVICE)
continue;
// GPU version of SGDUpdate scales gradient by the learning rate, this check will fail.
BOOST_CHECK(g1.IsEqualTo(g2));
}
lr = std::pow(lr, lr);
}
}
BOOST_FIXTURE_TEST_CASE(MatrixMomentumSGDUpdate_WithAndWithout_UnitGain, RandomSeedFixture)
{
const float low = -1.0f;
const float high = 1.0f;
float lr = 0.77f;
for (auto deviceId : {CPUDEVICE, c_deviceIdZero})
{
auto p1 = SingleMatrix::RandomUniform(12, 13, deviceId, low, high, IncrementCounter());
auto p2 = p1.DeepClone();
BOOST_ASSERT(p1.IsEqualTo(p2));
auto g1 = SingleMatrix::RandomUniform(12, 13, deviceId, low, high, IncrementCounter());
auto g2 = g1.DeepClone();
BOOST_ASSERT(g1.IsEqualTo(g2));
auto sg1 = SingleMatrix::RandomUniform(12, 13, deviceId, low, high, IncrementCounter());
auto sg2 = sg1.DeepClone();
BOOST_ASSERT(sg1.IsEqualTo(sg2));
for (; lr > 0.01; lr = lr / 2)
{
p1.MomentumSGDUpdate(g1, sg1, lr, 0.0, true);
p2.MomentumSGDUpdate(g2, sg2, lr, 0.0, false);
BOOST_CHECK(p1.IsEqualTo(p2));
}
for (lr = 1.0; lr > 0.03; lr = lr / 2)
{
p1.MomentumSGDUpdate(g1, sg1, lr, 0.5, true);
p2.MomentumSGDUpdate(g2, sg2, lr/2, 0.5, false);
BOOST_CHECK(p1.IsEqualTo(p2));
}
BOOST_CHECK(g1.IsEqualTo(g2));
BOOST_CHECK(sg1.IsEqualTo(sg2));
p1.MomentumSGDUpdate(g1, sg1, lr, 0.5, true);
p2.MomentumSGDUpdate(g2, sg2, lr, 0.5, false);
BOOST_CHECK(!p1.IsEqualTo(p2));
lr = std::pow(lr, lr);
}
}
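The second loop above leans on a simple identity: because the only difference between the two modes is the (1 - momentum) factor on the gradient term, a unit-gain update with momentum 0.5 and learning rate lr produces the same step as a non-unit-gain update with momentum 0.5 and learning rate lr/2, provided both start from the same smoothed gradient. A quick NumPy check of that identity, using the sketched update rule from the Matrix.cpp section (an assumption about the exact formula, not the CNTK source):

import numpy as np

def momentum_step(smoothed, grad, lr, momentum, unit_gain):
    gain = (1.0 - momentum) if unit_gain else 1.0
    return momentum * smoothed + gain * lr * grad  # new smoothed gradient == parameter delta

sg = np.array([0.3, -0.2])
g = np.array([1.0, 2.0])
assert np.allclose(momentum_step(sg, g, lr=0.8, momentum=0.5, unit_gain=True),
                   momentum_step(sg, g, lr=0.4, momentum=0.5, unit_gain=False))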
BOOST_FIXTURE_TEST_CASE(MatrixNesterovAcceleratedMomentumSGDUpdate_WithAndWithout_UnitGain, RandomSeedFixture)
{
const float low = -1.0f;
const float high = 1.0f;
float lr = 0.77f;
for (auto deviceId : {CPUDEVICE, c_deviceIdZero})
{
auto p1 = SingleMatrix::RandomUniform(12, 13, deviceId, low, high, IncrementCounter());
auto p2 = p1.DeepClone();
BOOST_ASSERT(p1.IsEqualTo(p2));
auto g1 = SingleMatrix::RandomUniform(12, 13, deviceId, low, high, IncrementCounter());
auto g2 = g1.DeepClone();
BOOST_ASSERT(g1.IsEqualTo(g2));
auto sg1 = SingleMatrix::RandomUniform(12, 13, deviceId, low, high, IncrementCounter());
auto sg2 = sg1.DeepClone();
BOOST_ASSERT(sg1.IsEqualTo(sg2));
for (; lr > 0.01; lr = lr / 2)
{
p1.NesterovAcceleratedMomentumSGDUpdate(g1, sg1, lr, 0.0, true);
p2.NesterovAcceleratedMomentumSGDUpdate(g2, sg2, lr, 0.0, false);
BOOST_CHECK(p1.IsEqualTo(p2));
}
for (lr = 1.0; lr > 0.03; lr = lr / 2)
{
p1.NesterovAcceleratedMomentumSGDUpdate(g1, sg1, lr, 0.5, true);
p2.NesterovAcceleratedMomentumSGDUpdate(g2, sg2, lr/2, 0.5, false);
BOOST_CHECK(p1.IsEqualTo(p2));
}
BOOST_CHECK(g1.IsEqualTo(g2));
BOOST_CHECK(sg1.IsEqualTo(sg2));
p1.NesterovAcceleratedMomentumSGDUpdate(g1, sg1, lr, 0.5, true);
p2.NesterovAcceleratedMomentumSGDUpdate(g2, sg2, lr, 0.5, false);
BOOST_CHECK(!p1.IsEqualTo(p2));
lr = std::pow(lr, lr);
}
}
BOOST_FIXTURE_TEST_CASE(MatrixFSAdagradUpdate_WithAndWithout_UnitGain, RandomSeedFixture)
{
const float low = -1.0f;
const float high = 1.0f;
float lr = 0.77f;
for (auto deviceId : {CPUDEVICE, c_deviceIdZero})
{
auto p1 = SingleMatrix::RandomUniform(12, 13, deviceId, low, high, IncrementCounter());
auto p2 = p1.DeepClone();
BOOST_ASSERT(p1.IsEqualTo(p2));
auto g1 = SingleMatrix::RandomUniform(12, 13, deviceId, low, high, IncrementCounter());
auto g2 = g1.DeepClone();
BOOST_ASSERT(g1.IsEqualTo(g2));
auto sg1 = SingleMatrix::RandomUniform(12, 13, deviceId, low, high, IncrementCounter());
auto sg2 = sg1.DeepClone();
BOOST_ASSERT(sg1.IsEqualTo(sg2));
for (; lr > 0.01; lr = lr / 2)
{
size_t mbSize = 100;
double smoothedCount = 10 / lr;
double targetAdagradAvDenom = 1.0;
double varMomentum = 1.0 - lr;
sg1.FSAdagradUpdate(mbSize, g1, p1, smoothedCount, lr, targetAdagradAvDenom, 0.0, varMomentum, true);
sg2.FSAdagradUpdate(mbSize, g2, p2, smoothedCount, lr, targetAdagradAvDenom, 0.0, varMomentum, true /*false*/);
// BUGBUG: at the moment this fails even with identical arguments.
// BOOST_CHECK(p1.IsEqualTo(p2));
}
sg2.SetValue(sg1);
BOOST_ASSERT(sg1.IsEqualTo(sg2));
for (lr = 1.0; lr > 0.03; lr = lr / 2)
{
size_t mbSize = 100;
double smoothedCount = 10 / lr;
double targetAdagradAvDenom = 1.0;
double varMomentum = 1.0 - lr;
sg1.FSAdagradUpdate(mbSize, g1, p1, smoothedCount, lr, targetAdagradAvDenom, 0.5, varMomentum, true);
sg2.FSAdagradUpdate(mbSize, g2, p2, smoothedCount, lr /*lr/2*/, targetAdagradAvDenom, 0.5, varMomentum, true /*false*/);
// BUGBUG: at the moment this fails even with identical arguments.
// BOOST_CHECK(p1.IsEqualTo(p2));
}
lr = std::pow(lr, lr);
}
}
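As with the plain momentum tests, the intent here is that the unit-gain flag rescales only the first-moment (mean) smoothing inside FSAdagradUpdate. A loose Adam/FSAdaGrad-style sketch of where that factor would act; the variance accumulator and denominator below are simplified assumptions and do not reproduce CNTK's FSAdagradUpdate:

import numpy as np

def adam_like_update(param, grad, mean, var, lr, mean_momentum, var_momentum, unit_gain):
    gain = (1.0 - mean_momentum) if unit_gain else 1.0
    var[:] = var_momentum * var + (1.0 - var_momentum) * grad * grad  # second-moment estimate
    mean[:] = mean_momentum * mean + gain * grad                      # unit gain acts here only
    param -= lr * mean / (np.sqrt(var) + 1e-8)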
BOOST_AUTO_TEST_SUITE_END()
}
} } }


@ -56,24 +56,24 @@ void TestSGDLearner(size_t numParameters, size_t numMinibatches, const DeviceDes
}
template <typename ElementType>
-void TestMomentumSGDLearner(size_t numParameters, size_t numMinibatches, const DeviceDescriptor& device)
+void TestMomentumSGDLearner(size_t numParameters, size_t numMinibatches, bool unitGainMomentum, const DeviceDescriptor& device)
{
    NDShape shape = CreateShape(rng() % maxNumAxes + 1, maxDimSize);
    auto parameters = CreateParameters<ElementType>(shape, numParameters, device);
    LearningRatePerMinibatchSchedule learnigRateSchedule = { { 3.0, 2.0, 1.0 }, numMinibatches };
    MomentumPerSampleSchedule momentumValues = { { { 1, 1.0 }, { 3, 0.1 }, { 10, 0.01 } }, 2 };
-   auto learner = MomentumSGDLearner(parameters, learnigRateSchedule, momentumValues);
+   auto learner = MomentumSGDLearner(parameters, learnigRateSchedule, momentumValues, unitGainMomentum);
    TestUpdate<ElementType>(learner, shape, numMinibatches, device);
    FloatingPointCompare(learner->LearningRate(), 2.0, "Learner::LearningRate does not match expectation");
}
template <typename ElementType>
-void TestNesterovLearner(size_t numParameters, size_t numMinibatches, const DeviceDescriptor& device)
+void TestNesterovLearner(size_t numParameters, size_t numMinibatches, bool unitGainMomentum, const DeviceDescriptor& device)
{
    NDShape shape = CreateShape(rng() % maxNumAxes + 1, maxDimSize);
    auto parameters = CreateParameters<ElementType>(shape, numParameters, device);
    MomentumAsTimeConstantSchedule momentumValues = { { { 1, 1 }, { 3, 5 }, { 10, 25 } }, 100 };
-   auto learner = NesterovLearner(parameters, LearningRatePerMinibatchSchedule( { { 1, 0.5 }, { 10, 0.25 }, { 20, 0.125 } }, 3 ), momentumValues);
+   auto learner = NesterovLearner(parameters, LearningRatePerMinibatchSchedule( { { 1, 0.5 }, { 10, 0.25 }, { 20, 0.125 } }, 3 ), momentumValues, unitGainMomentum);
    TestUpdate<ElementType>(learner, shape, numMinibatches, device);
}
@ -87,11 +87,11 @@ void TestAdaGradLearner(size_t numParameters, size_t numMinibatches, const Devic
}
template <typename ElementType>
-void TestFSAdaGradLearner(size_t numParameters, size_t numMinibatches, const DeviceDescriptor& device)
+void TestFSAdaGradLearner(size_t numParameters, size_t numMinibatches, bool unitGainMomentum, const DeviceDescriptor& device)
{
    NDShape shape = CreateShape(rng() % maxNumAxes + 1, maxDimSize);
    auto parameters = CreateParameters<ElementType>(shape, numParameters, device);
-   auto learner = AdamLearner(parameters, LearningRatePerSampleSchedule({ 0.5 }), MomentumAsTimeConstantSchedule({ 10.0, 100.0, 1000.0 }));
+   auto learner = AdamLearner(parameters, LearningRatePerSampleSchedule({ 0.5 }), MomentumAsTimeConstantSchedule({ 10.0, 100.0, 1000.0 }), unitGainMomentum);
    TestUpdate<ElementType>(learner, shape, numMinibatches, device);
}
@ -235,21 +235,33 @@ void LearnerTests()
    TestTrainingParametersSchedule();
-   TestSGDLearner<double>(5, 3, DeviceDescriptor::CPUDevice());
-   TestMomentumSGDLearner<float>(3, 4, DeviceDescriptor::CPUDevice());
-   TestNesterovLearner<float>(1, 4, DeviceDescriptor::CPUDevice());
-   TestAdaGradLearner<double>(2, 5, DeviceDescriptor::CPUDevice());
-   TestFSAdaGradLearner<double>(10, 2, DeviceDescriptor::CPUDevice());
-   TestRMSPropLearner<float>(3, 3, DeviceDescriptor::CPUDevice());
+   vector<DeviceDescriptor> devices{DeviceDescriptor::CPUDevice()};
    if (IsGPUAvailable())
    {
-       TestSGDLearner<double>(1, 1, DeviceDescriptor::CPUDevice());
-       TestMomentumSGDLearner<float>(3, 5, DeviceDescriptor::GPUDevice(0));
-       TestNesterovLearner<float>(1, 4, DeviceDescriptor::GPUDevice(0));
-       TestAdaGradLearner<double>(1, 2, DeviceDescriptor::GPUDevice(0));
-       TestFSAdaGradLearner<double>(2, 2, DeviceDescriptor::GPUDevice(0));
-       TestRMSPropLearner<float>(3, 3, DeviceDescriptor::GPUDevice(0));
+       devices.push_back(DeviceDescriptor::GPUDevice(0));
    }
+   srand(1);
+   for (auto& device : devices)
+   {
+       auto numParameters = 1 + rand() % 5;
+       auto numMinibatches = 1 + rand() % 5;
+       TestSGDLearner<double>(numParameters, numMinibatches, device);
+       TestAdaGradLearner<double>(numParameters, numMinibatches, device);
+       TestRMSPropLearner<float>(numParameters, numMinibatches, device);
+   }
+   for (auto& device : devices)
+   {
+       for (auto unitGain : { true, false })
+       {
+           auto numParameters = 1 + rand() % 5;
+           auto numMinibatches = 1 + rand() % 5;
+           TestMomentumSGDLearner<float>(numParameters, numMinibatches, unitGain, device);
+           TestNesterovLearner<float>(numParameters, numMinibatches, unitGain, device);
+           TestFSAdaGradLearner<double>(numParameters, numMinibatches, unitGain, device);
+       }
+   }
}


@ -180,7 +180,7 @@ void TrainSequenceToSequenceTranslator(const DeviceDescriptor& device, bool useS
    AdditionalLearningOptions additionalOptions;
    additionalOptions.gradientClippingThresholdPerSample = 2.3;
    additionalOptions.gradientClippingWithTruncation = true;
-   Trainer trainer(z, ce, errs, { MomentumSGDLearner(z->Parameters(), learningRatePerSample, momentumTimeConstant, additionalOptions) });
+   Trainer trainer(z, ce, errs, { MomentumSGDLearner(z->Parameters(), learningRatePerSample, momentumTimeConstant, /*unitGainMomentum = */true, additionalOptions) });
    size_t outputFrequencyInMinibatches = 1;
    size_t minibatchSize1 = 72;


@ -48,7 +48,9 @@ void TrainLSTMSequenceClassifer(const DeviceDescriptor& device, bool useSparseLa
    LearningRatePerSampleSchedule learningRatePerSample = 0.0005;
    MomentumAsTimeConstantSchedule momentumTimeConstant = 256;
-   Trainer trainer(classifierOutput, trainingLoss, prediction, { MomentumSGDLearner(classifierOutput->Parameters(), learningRatePerSample, momentumTimeConstant) });
+   Trainer trainer(classifierOutput, trainingLoss, prediction,
+                   { MomentumSGDLearner(classifierOutput->Parameters(), learningRatePerSample,
+                                        momentumTimeConstant, /*unitGainMomentum = */true) });
    size_t outputFrequencyInMinibatches = 1;
    for (size_t i = 0; true; i++)


@ -416,7 +416,7 @@ Trainer BuildTrainer(const FunctionPtr& function, const Variable& labels,
{
    auto trainingLoss = CNTK::CrossEntropyWithSoftmax(function, labels, L"lossFunction");
    auto prediction = CNTK::ClassificationError(function, labels, L"classificationError");
-   auto learner = MomentumSGDLearner(function->Parameters(), lr, m);
+   auto learner = MomentumSGDLearner(function->Parameters(), lr, m, /*unitGainMomentum = */true);
    return Trainer(function, trainingLoss, prediction, { learner });
}
@ -674,7 +674,7 @@ void TestLegacyModelSaving(const DeviceDescriptor& device)
    LearningRatePerSampleSchedule learningRateSchedule2({ { 0.04, 0.02, 0.01, 0.008, 0.004, 0.002, 0.001 } }, actualMBSize);
    MomentumAsTimeConstantSchedule momentumSchedule({ { 900, 800, 700, 600, 500 } }, actualMBSize);
-   auto learner2 = AdamLearner(classifierOutput->Parameters(), learningRateSchedule, momentumSchedule);
+   auto learner2 = AdamLearner(classifierOutput->Parameters(), learningRateSchedule, momentumSchedule, /*unitGainMomentum = */true);
    Trainer trainer2(classifierOutput, trainingLoss, prediction, { learner });


@ -111,7 +111,7 @@ void TrainTruncatedLSTMAcousticModelClassifer(const DeviceDescriptor& device, bo
    LearningRatePerSampleSchedule learningRatePerSample = 0.000781;
    MomentumAsTimeConstantSchedule momentumTimeConstant = 6074;
-   auto learner = MomentumSGDLearner(classifierOutput->Parameters(), learningRatePerSample, momentumTimeConstant);
+   auto learner = MomentumSGDLearner(classifierOutput->Parameters(), learningRatePerSample, momentumTimeConstant, /*unitGainMomentum = */true);
    Trainer trainer(classifierOutput, trainingLoss, prediction, {learner});
    size_t outputFrequencyInMinibatches = 1;


@ -271,7 +271,7 @@
" # Feel free to try other optimizers from \n", " # Feel free to try other optimizers from \n",
" # https://www.cntk.ai/pythondocs/cntk.learner.html#module-cntk.learner\n", " # https://www.cntk.ai/pythondocs/cntk.learner.html#module-cntk.learner\n",
" learner = adam_sgd(model.parameters,\n", " learner = adam_sgd(model.parameters,\n",
" lr=lr_schedule, momentum=momentum_as_time_constant) \n", " lr=lr_schedule, momentum=momentum_as_time_constant, unit_gain=True) \n",
" \n", " \n",
" # Instantiate the trainer\n", " # Instantiate the trainer\n",
" trainer = Trainer(model, loss, label_error, learner)\n", " trainer = Trainer(model, loss, label_error, learner)\n",


@ -324,6 +324,7 @@
" # trainer object\n", " # trainer object\n",
" learner = momentum_sgd(z.parameters, \n", " learner = momentum_sgd(z.parameters, \n",
" lr = lr_per_minibatch, momentum = momentum_time_constant, \n", " lr = lr_per_minibatch, momentum = momentum_time_constant, \n",
" unit_gain = True, \n",
" l2_regularization_weight=l2_reg_weight)\n", " l2_regularization_weight=l2_reg_weight)\n",
" trainer = Trainer(z, ce, pe, [learner])\n", " trainer = Trainer(z, ce, pe, [learner])\n",
"\n", "\n",


@ -392,6 +392,7 @@
" # https://www.cntk.ai/pythondocs/cntk.learner.html#module-cntk.learner\n", " # https://www.cntk.ai/pythondocs/cntk.learner.html#module-cntk.learner\n",
" learner = adam_sgd(criterion.parameters,\n", " learner = adam_sgd(criterion.parameters,\n",
" lr=lr_schedule, momentum=momentum_as_time_constant,\n", " lr=lr_schedule, momentum=momentum_as_time_constant,\n",
" unit_gain = True, \n",
" low_memory=True,\n", " low_memory=True,\n",
" gradient_clipping_threshold_per_sample=15, gradient_clipping_with_truncation=True)\n", " gradient_clipping_threshold_per_sample=15, gradient_clipping_with_truncation=True)\n",
"\n", "\n",
@ -497,7 +498,7 @@
" lr_schedule = learning_rate_schedule(1, UnitType.minibatch)\n", " lr_schedule = learning_rate_schedule(1, UnitType.minibatch)\n",
" momentum_as_time_constant = momentum_as_time_constant_schedule(0)\n", " momentum_as_time_constant = momentum_as_time_constant_schedule(0)\n",
" dummy_learner = adam_sgd(criterion.parameters, \n", " dummy_learner = adam_sgd(criterion.parameters, \n",
" lr=lr_schedule, momentum=momentum_as_time_constant, low_memory=True)\n", " lr=lr_schedule, momentum=momentum_as_time_constant, unit_gain=True, low_memory=True)\n",
" evaluator = Trainer(model, criterion.outputs[0], criterion.outputs[1], dummy_learner)\n", " evaluator = Trainer(model, criterion.outputs[0], criterion.outputs[1], dummy_learner)\n",
" progress_printer = ProgressPrinter(tag='Evaluation')\n", " progress_printer = ProgressPrinter(tag='Evaluation')\n",
"\n", "\n",


@ -753,6 +753,7 @@
"gradient_clipping_with_truncation = True\n", "gradient_clipping_with_truncation = True\n",
"learner = momentum_sgd(model.parameters,\n", "learner = momentum_sgd(model.parameters,\n",
" lr_per_sample, momentum_time_constant,\n", " lr_per_sample, momentum_time_constant,\n",
" unit_gain = True, \n",
" gradient_clipping_threshold_per_sample=clipping_threshold_per_sample,\n", " gradient_clipping_threshold_per_sample=clipping_threshold_per_sample,\n",
" gradient_clipping_with_truncation=gradient_clipping_with_truncation)\n", " gradient_clipping_with_truncation=gradient_clipping_with_truncation)\n",
"trainer = Trainer(model, ce, errs, learner)" "trainer = Trainer(model, ce, errs, learner)"
@ -919,6 +920,7 @@
" gradient_clipping_with_truncation = True\n", " gradient_clipping_with_truncation = True\n",
" learner = momentum_sgd(model.parameters,\n", " learner = momentum_sgd(model.parameters,\n",
" lr_per_sample, momentum_time_constant,\n", " lr_per_sample, momentum_time_constant,\n",
" unit_gain = True, \n",
" gradient_clipping_threshold_per_sample=clipping_threshold_per_sample,\n", " gradient_clipping_threshold_per_sample=clipping_threshold_per_sample,\n",
" gradient_clipping_with_truncation=gradient_clipping_with_truncation)\n", " gradient_clipping_with_truncation=gradient_clipping_with_truncation)\n",
" trainer = Trainer(model, ce, errs, learner)\n", " trainer = Trainer(model, ce, errs, learner)\n",


@ -308,16 +308,19 @@
                         AdditionalLearningOptions additionalOptions = AdditionalLearningOptions());
%ignore CNTK::MomentumSGDLearner(const std::vector<Parameter>& parameters,
                                 const LearningRateSchedule& learningRateSchedule,
                                 const MomentumSchedule& momentumSchedule,
+                                bool unitGain,
                                 AdditionalLearningOptions additionalOptions = AdditionalLearningOptions());
%ignore CNTK::NesterovLearner(const std::vector<Parameter>& parameters,
                              const LearningRateSchedule& learningRateSchedule,
                              const MomentumSchedule& momentumSchedule,
+                             bool unitGain,
                              AdditionalLearningOptions additionalOptions = AdditionalLearningOptions());
%ignore CNTK::DefaultVarianceMomentum;
%ignore CNTK::AdamLearner(const std::vector<Parameter>& parameters,
                          const LearningRateSchedule& learningRateSchedule,
                          const MomentumSchedule& momentumSchedule,
+                         bool unitGain,
                          const MomentumSchedule& varianceMomentumSchedule = DefaultVarianceMomentum,
                          bool lowMemory = true,
                          AdditionalLearningOptions additionalOptions = AdditionalLearningOptions());


@ -343,7 +343,7 @@ def sgd(parameters, lr,
    return cntk_py.sgd_learner(parameters, lr, additional_options)
@typemap
-def momentum_sgd(parameters, lr, momentum,
+def momentum_sgd(parameters, lr, momentum, unit_gain,
                 l1_regularization_weight=0.0, l2_regularization_weight=0.0,
                 gaussian_noise_injection_std_dev=0.0, gradient_clipping_threshold_per_sample=np.inf,
                 gradient_clipping_with_truncation=True):
@ -358,6 +358,7 @@ def momentum_sgd(parameters, lr, momentum,
            :func:`momentum_as_time_constant_schedule`): momentum schedule.
            For additional information, please refer to the `wiki
            <https://github.com/Microsoft/CNTK/wiki/SGD-block#converting-learning-rate-and-momentum-parameters-from-other-toolkits>`_.
+       unit_gain: when ``True``, momentum is interpreted as a unit-gain filter.
        l1_regularization_weight (float, optional): the L1 regularization weight per sample,
            defaults to 0.0
        l2_regularization_weight (float, optional): the L2 regularization weight per sample,
@ -385,11 +386,11 @@ def momentum_sgd(parameters, lr, momentum,
    additional_options.gradient_clipping_threshold_per_sample = gradient_clipping_threshold_per_sample
    additional_options.gradient_clipping_with_truncation = gradient_clipping_with_truncation
-   return cntk_py.momentum_sgd_learner(parameters, lr, momentum,
+   return cntk_py.momentum_sgd_learner(parameters, lr, momentum, unit_gain,
                                        additional_options)
@typemap
-def nesterov(parameters, lr, momentum,
+def nesterov(parameters, lr, momentum, unit_gain,
             l1_regularization_weight=0.0, l2_regularization_weight=0.0,
             gaussian_noise_injection_std_dev=0.0, gradient_clipping_threshold_per_sample=np.inf,
             gradient_clipping_with_truncation=True):
@ -406,6 +407,7 @@ def nesterov(parameters, lr, momentum,
            :func:`momentum_as_time_constant_schedule`): momentum schedule.
            For additional information, please refer to the `wiki
            <https://github.com/Microsoft/CNTK/wiki/SGD-block#converting-learning-rate-and-momentum-parameters-from-other-toolkits>`_.
+       unit_gain: when ``True``, momentum is interpreted as a unit-gain filter.
        l1_regularization_weight (float, optional): the L1 regularization weight per sample,
            defaults to 0.0
        l2_regularization_weight (float, optional): the L2 regularization weight per sample,
@ -442,7 +444,7 @@ def nesterov(parameters, lr, momentum,
    additional_options.gradient_clipping_threshold_per_sample = gradient_clipping_threshold_per_sample
    additional_options.gradient_clipping_with_truncation = gradient_clipping_with_truncation
-   return cntk_py.nesterov_learner(parameters, lr, momentum,
+   return cntk_py.nesterov_learner(parameters, lr, momentum, unit_gain,
                                    additional_options)
@typemap
@ -495,7 +497,7 @@ def adagrad(parameters, lr, need_ave_multiplier=True,
# TODO: unCamelCase and integrate upcoming CR
@typemap
-def adam_sgd(parameters, lr, momentum,
+def adam_sgd(parameters, lr, momentum, unit_gain,
             variance_momentum = momentum_as_time_constant_schedule(720000),
             low_memory=True,
             l1_regularization_weight=0.0, l2_regularization_weight=0.0,
@ -513,6 +515,7 @@ def adam_sgd(parameters, lr, momentum,
            :func:`momentum_as_time_constant_schedule`): momentum schedule.
            For additional information, please refer to the `wiki
            <https://github.com/Microsoft/CNTK/wiki/SGD-block#converting-learning-rate-and-momentum-parameters-from-other-toolkits>`_.
+       unit_gain: when ``True``, momentum is interpreted as a unit-gain filter.
        variance_momentum (output of :func:`momentum_schedule` or
            :func:`momentum_as_time_constant_schedule`): variance momentum schedule. Defaults
            to ``momentum_as_time_constant_schedule(720000)``.
@ -551,7 +554,7 @@ def adam_sgd(parameters, lr, momentum,
    additional_options.gradient_clipping_threshold_per_sample = gradient_clipping_threshold_per_sample
    additional_options.gradient_clipping_with_truncation = gradient_clipping_with_truncation
-   return cntk_py.adam_learner(parameters, lr, momentum,
+   return cntk_py.adam_learner(parameters, lr, momentum, unit_gain,
                                variance_momentum, low_memory, additional_options)
@typemap
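
With unit_gain now a required positional parameter of momentum_sgd, nesterov and adam_sgd, existing call sites have to be updated. A minimal usage sketch mirroring the test files below; z stands in for an existing model Function and the schedule values are placeholders:

from cntk.learner import (momentum_sgd, nesterov, adam_sgd, learning_rate_schedule,
                          momentum_as_time_constant_schedule, UnitType)

# z is assumed to be an existing CNTK Function (e.g., the network output)
lr_per_sample = learning_rate_schedule(0.007, UnitType.sample)
momentum_time_constant = momentum_as_time_constant_schedule(1100)

# unit_gain can be passed positionally or by keyword
m_learner = momentum_sgd(z.parameters, lr_per_sample, momentum_time_constant, True)
n_learner = nesterov(z.parameters, lr=lr_per_sample, momentum=momentum_time_constant, unit_gain=False)
a_learner = adam_sgd(z.parameters, lr=lr_per_sample, momentum=momentum_time_constant, unit_gain=True)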


@ -45,7 +45,7 @@ def run_distributed_training(tmpdir, create_func):
    momentum_time_constant = momentum_as_time_constant_schedule(1100)
    lr_per_sample = learning_rate_schedule(0.007, UnitType.sample)
-   dist_learner = create_func(momentum_sgd(z.parameters, lr_per_sample, momentum_time_constant))
+   dist_learner = create_func(momentum_sgd(z.parameters, lr_per_sample, momentum_time_constant, True))
    communicator = dist_learner.communicator()
    workers = communicator.workers()


@ -73,16 +73,16 @@ def test_learner_init():
    momentum_time_constant = momentum_as_time_constant_schedule(1100)
    lr_per_sample = learning_rate_schedule(0.1, UnitType.sample)
-   momentum_sgd(res.parameters, lr_per_sample, momentum_time_constant)
+   momentum_sgd(res.parameters, lr_per_sample, momentum_time_constant, True)
    lr_per_sample = learning_rate_schedule([0.1, 0.2], UnitType.sample)
-   nesterov(res.parameters, lr=lr_per_sample, momentum=momentum_time_constant)
+   nesterov(res.parameters, lr=lr_per_sample, momentum=momentum_time_constant, unit_gain=False)
    lr_per_sample = learning_rate_schedule([0.1]*3 +[0.2]*2 +[0.3], UnitType.sample)
    adagrad(res.parameters, lr=lr_per_sample, need_ave_multiplier=True)
    lr_per_sample = learning_rate_schedule([(3,0.1), (2, 0.2), (1, 0.3)], UnitType.sample)
-   adam_sgd(res.parameters, lr=lr_per_sample, momentum=momentum_time_constant)
+   adam_sgd(res.parameters, lr=lr_per_sample, momentum=momentum_time_constant, unit_gain=True)
    gamma, inc, dec, max, min = [0.1]*5
    lr_per_sample = learning_rate_schedule([0.1, 0.2], UnitType.sample, 100)


@ -29,7 +29,7 @@ def test_trainer(tmpdir):
    momentum_time_constant = momentum_as_time_constant_schedule(1100)
    lr_per_sample = learning_rate_schedule(0.007, UnitType.sample)
    trainer = Trainer(z, ce, errs,
-                     [momentum_sgd(z.parameters, lr_per_sample, momentum_time_constant)])
+                     [momentum_sgd(z.parameters, lr_per_sample, momentum_time_constant, True)])
    in1_value = [[1],[2]]
    label_value = [[0], [1]]
    arguments = {in1: in1_value, labels: label_value}
@ -57,7 +57,7 @@ def test_output_to_retain():
    momentum_time_constant = momentum_as_time_constant_schedule(1100)
    lr_per_sample = learning_rate_schedule(0.007, UnitType.sample)
    trainer = Trainer(z, ce, errs,
-                     [momentum_sgd(z.parameters, lr_per_sample, momentum_time_constant)])
+                     [momentum_sgd(z.parameters, lr_per_sample, momentum_time_constant, True)])
    in1_value = [[[1]], [[2]]]
    label_value = [[0], [1]]
    arguments = {in1: in1_value, labels: label_value}