Integrate alrezni/v2_scratch into master

This commit is contained in:
Project Philly 2016-11-14 10:59:42 -08:00
Parent 37bc98867f 4ef4a91130
Commit babf078b49
23 changed files with 247 additions and 227 deletions

View file

@ -124,7 +124,7 @@ def train_and_evaluate(reader_train, reader_test, max_epochs):
minibatch_size = 64
# Set learning parameters
lr_per_minibatch = learning_rate_schedule([0.01]*10 + [0.003]*10 + [0.001], epoch_size, UnitType.minibatch)
lr_per_minibatch = learning_rate_schedule([0.01]*10 + [0.003]*10 + [0.001], UnitType.minibatch, epoch_size)
momentum_time_constant = momentum_as_time_constant_schedule(-minibatch_size/np.log(0.9))
l2_reg_weight = 0.0001
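The change above reflects the new learning_rate_schedule signature: the unit is now the second positional argument and the epoch size comes last. A minimal sketch of the reordered call, with an illustrative epoch_size (the real value is defined earlier in the training script):

from cntk.learner import learning_rate_schedule, UnitType

epoch_size = 50000  # illustrative stand-in; the script defines the real value
# old order: learning_rate_schedule(values, epoch_size, UnitType.minibatch)
# new order: the unit comes second, the optional epoch size last
lr_per_minibatch = learning_rate_schedule([0.01]*10 + [0.003]*10 + [0.001],
                                          UnitType.minibatch, epoch_size)
print(lr_per_minibatch[0])  # 0.01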

View file

@ -10,7 +10,7 @@ import os
from cntk import Trainer
from cntk.io import MinibatchSource, CTFDeserializer, StreamDef, StreamDefs, INFINITELY_REPEAT, FULL_DATA_SWEEP
from cntk.device import cpu, set_default_device
from cntk.learner import sgd
from cntk.learner import sgd, learning_rate_schedule, UnitType
from cntk.ops import input_variable, cross_entropy_with_softmax, classification_error, relu, element_times, constant
abs_path = os.path.dirname(os.path.abspath(__file__))
@ -66,8 +66,9 @@ def simple_mnist(debug_output=False):
label : reader_train.streams.labels
}
lr_per_minibatch=learning_rate_schedule(0.2, UnitType.minibatch)
# Instantiate the trainer object to drive the model training
trainer = Trainer(z, ce, pe, sgd(z.parameters, lr=1./320))
trainer = Trainer(z, ce, pe, sgd(z.parameters, lr=lr_per_minibatch))
# Get minibatches of images to train with and perform model training
minibatch_size = 64
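Putting the two changed lines together: sgd no longer accepts a bare float, so the rate is wrapped in an explicit per-minibatch schedule first. A sketch assuming z, ce and pe are the network, loss and error nodes built earlier in simple_mnist:

from cntk import Trainer
from cntk.learner import sgd, learning_rate_schedule, UnitType

# 0.2 is interpreted as a per-minibatch rate because of UnitType.minibatch
lr_per_minibatch = learning_rate_schedule(0.2, UnitType.minibatch)
trainer = Trainer(z, ce, pe, sgd(z.parameters, lr=lr_per_minibatch))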

View file

@ -170,7 +170,7 @@ def train_and_evaluate(reader_train, reader_test, max_epochs):
minibatch_size = 128
# Set learning parameters
lr_per_minibatch = learning_rate_schedule([1]*80 + [0.1]*40 + [0.01], epoch_size, UnitType.minibatch)
lr_per_minibatch = learning_rate_schedule([1]*80 + [0.1]*40 + [0.01], UnitType.minibatch, epoch_size)
momentum_time_constant = momentum_as_time_constant_schedule(-minibatch_size/np.log(0.9))
l2_reg_weight = 0.0001

View file

@ -68,7 +68,8 @@ def cifar_resnet_distributed(data_path, run_test, num_epochs, communicator=None,
num_mbs = num_mb_per_epoch * num_epochs
lr_per_minibatch = learning_rate_schedule([1]*80 + [0.1]*40 + [0.01], mb_size * num_mb_per_epoch, UnitType.minibatch)
lr_schedule = [1.0/mb_size]*80 + [0.1/mb_size]*40 + [0.01/mb_size]
lr_per_minibatch = learning_rate_schedule(lr_schedule, UnitType.minibatch, mb_size * num_mb_per_epoch)
momentum_time_constant = momentum_as_time_constant_schedule(-mb_size/np.log(0.9))
# create data parallel distributed trainer if needed
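Here the rates are additionally divided by the minibatch size before the per-minibatch schedule is built. A sketch with illustrative mb_size and num_mb_per_epoch values (the script defines the real ones):

from cntk.learner import learning_rate_schedule, UnitType

mb_size = 128            # illustrative
num_mb_per_epoch = 390   # illustrative
lr_schedule = [1.0/mb_size]*80 + [0.1/mb_size]*40 + [0.01/mb_size]
lr_per_minibatch = learning_rate_schedule(lr_schedule, UnitType.minibatch,
                                          mb_size * num_mb_per_epoch)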

View file

@ -12,7 +12,7 @@ from cntk.models import * # higher abstraction level, e.g. entire standard mode
from cntk.utils import *
from cntk.io import MinibatchSource, CTFDeserializer, StreamDef, StreamDefs
from cntk import Trainer
from cntk.learner import adam_sgd, learning_rate_schedule, momentum_as_time_constant_schedule
from cntk.learner import adam_sgd, learning_rate_schedule, UnitType, momentum_as_time_constant_schedule
from cntk.ops import cross_entropy_with_softmax, classification_error
########################
@ -76,10 +76,10 @@ def train(reader, model, max_epochs):
num_mbs_to_show_result = 100
momentum_time_constant = momentum_as_time_constant_schedule(minibatch_size / -math.log(0.9)) # TODO: Change to round number. This is 664.39. 700?
lr_schedule = [0.003]*2+[0.0015]*12+[0.0003] # LR schedule over epochs (we don't run that mayn epochs, but if we did, these are good values)
lr_schedule = [0.003]*2+[0.0015]*12+[0.0003] # LR schedule over epochs (we don't run that many epochs, but if we did, these are good values)
# trainer object
lr_per_sample = learning_rate_schedule(lr_schedule, epoch_size)
lr_per_sample = learning_rate_schedule(lr_schedule, UnitType.sample, epoch_size)
learner = adam_sgd(z.parameters,
lr=lr_per_sample, momentum=momentum_time_constant,
low_memory=True,

View file

@ -9,7 +9,7 @@ import os
from cntk import Trainer, Axis #, text_format_minibatch_source, StreamConfiguration
from cntk.io import MinibatchSource, CTFDeserializer, StreamDef, StreamDefs, INFINITELY_REPEAT, FULL_DATA_SWEEP
from cntk.device import cpu, set_default_device
from cntk.learner import sgd
from cntk.learner import sgd, learning_rate_schedule, UnitType
from cntk.ops import input_variable, cross_entropy_with_softmax, classification_error, sequence
abs_path = os.path.dirname(os.path.abspath(__file__))
@ -62,9 +62,10 @@ def train_sequence_classifier(debug_output=False):
label : reader.streams.labels
}
lr_per_sample = learning_rate_schedule(0.0005, UnitType.sample)
# Instantiate the trainer object to drive the model training
trainer = Trainer(classifier_output, ce, pe,
sgd(classifier_output.parameters, lr=0.0005))
sgd(classifier_output.parameters, lr=lr_per_sample))
# Get minibatches of sequences to train with and perform model training
minibatch_size = 200
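Per-sample rates follow the same pattern, just with UnitType.sample. A sketch assuming classifier_output is the model built earlier in the example:

from cntk.learner import sgd, learning_rate_schedule, UnitType

lr_per_sample = learning_rate_schedule(0.0005, UnitType.sample)
learner = sgd(classifier_output.parameters, lr=lr_per_sample)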

View file

@ -10,7 +10,7 @@ import os
from cntk import Trainer, Axis, save_model, load_model #, text_format_minibatch_source, StreamConfiguration
from cntk.io import MinibatchSource, CTFDeserializer, StreamDef, StreamDefs, INFINITELY_REPEAT, FULL_DATA_SWEEP
from cntk.device import cpu, set_default_device
from cntk.learner import momentum_sgd, momentum_as_time_constant_schedule
from cntk.learner import learning_rate_schedule, UnitType, momentum_sgd, momentum_as_time_constant_schedule
from cntk.ops import input_variable, cross_entropy_with_softmax, classification_error, sequence, past_value, future_value, element_select, alias, hardmax
from cntk.ops.functions import CloneMethod
@ -151,13 +151,12 @@ def sequence_to_sequence_translator(debug_output=False, run_test=False):
ng = z.clone(CloneMethod.share, {decoder_history_hook.output : net_output.output})
# Instantiate the trainer object to drive the model training
lr_per_sample = 0.007
minibatch_size = 72
lr_per_minibatch = learning_rate_schedule(0.5, UnitType.minibatch)
momentum_time_constant = momentum_as_time_constant_schedule(1100)
clipping_threshold_per_sample = 2.3
gradient_clipping_with_truncation = True
learner = momentum_sgd(z.parameters,
lr_per_sample, momentum_time_constant,
lr_per_minibatch, momentum_time_constant,
gradient_clipping_threshold_per_sample=clipping_threshold_per_sample,
gradient_clipping_with_truncation=gradient_clipping_with_truncation)
trainer = Trainer(z, ce, errs, learner)
@ -185,6 +184,7 @@ def sequence_to_sequence_translator(debug_output=False, run_test=False):
# Get minibatches of sequences to train with and perform model training
i = 0
mbs = 0
minibatch_size = 72
epoch_size = 908241
max_epochs = 10
training_progress_output_freq = 500
@ -240,7 +240,7 @@ def sequence_to_sequence_translator(debug_output=False, run_test=False):
ce = cross_entropy_with_softmax(z, label_sequence)
errs = classification_error(z, label_sequence)
trainer = Trainer(z, ce, errs, [momentum_sgd(
z.parameters, lr_per_sample, momentum_time_constant, clipping_threshold_per_sample, gradient_clipping_with_truncation)])
z.parameters, lr_per_minibatch, momentum_time_constant, clipping_threshold_per_sample, gradient_clipping_with_truncation)])
error2 = translator_test_error(z, trainer, input_vocab_dim, label_vocab_dim)
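momentum_sgd now takes both the learning rate and the momentum as schedule objects. A sketch of the updated seq2seq setup, assuming z is the model built earlier in the script:

from cntk.learner import momentum_sgd, learning_rate_schedule, UnitType, momentum_as_time_constant_schedule

lr_per_minibatch = learning_rate_schedule(0.5, UnitType.minibatch)
momentum_time_constant = momentum_as_time_constant_schedule(1100)
learner = momentum_sgd(z.parameters, lr_per_minibatch, momentum_time_constant,
                       gradient_clipping_threshold_per_sample=2.3,
                       gradient_clipping_with_truncation=True)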

View file

@ -2963,14 +2963,14 @@ namespace CNTK
///
/// Create a schedule with a constant parameter value.
///
CNTK_API TrainingParameterSchedule(T value, UnitType unit = UnitType::Sample);
CNTK_API TrainingParameterSchedule(T value, UnitType unit);
///
/// Create a schedule where the parameter changes its value every 'epochSize' samples:
/// schedule[0] is used for the first 'epochSize' samples, schedule[1] -- for the second,
/// and so on. The last value is then used repeatedly until the end of training.
///
CNTK_API TrainingParameterSchedule(const std::vector<T>& schedule, size_t epochSize = 1, UnitType unit = UnitType::Sample);
CNTK_API TrainingParameterSchedule(const std::vector<T>& schedule, UnitType unit, size_t epochSize = 1);
///
/// Create a schedule using the list of key-value pairs, where the key specifies
@ -2981,7 +2981,7 @@ namespace CNTK
/// the first 100 samples, then '0.1' is used for the second 200 samples,
/// after which the values is switched to '0.005'.
///
CNTK_API TrainingParameterSchedule(const std::vector<std::pair<size_t, T>>& schedule, size_t epochSize = 1, UnitType unit = UnitType::Sample);
CNTK_API TrainingParameterSchedule(const std::vector<std::pair<size_t, T>>& schedule, UnitType unit, size_t epochSize = 1);
///
/// Returns a value corresponding to the absolute sample (or sweep)
@ -3033,11 +3033,11 @@ namespace CNTK
{ }
TrainingParameterPerUnitSchedule(const std::vector<double>& schedule, size_t epochSize = 1)
: TrainingParameterSchedule<T>::TrainingParameterSchedule(schedule, epochSize, U)
: TrainingParameterSchedule<T>::TrainingParameterSchedule(schedule, U, epochSize)
{ }
TrainingParameterPerUnitSchedule(const std::vector<std::pair<size_t, double>>& schedule, size_t epochSize = 1)
: TrainingParameterSchedule<T>::TrainingParameterSchedule(schedule, epochSize, U)
: TrainingParameterSchedule<T>::TrainingParameterSchedule(schedule, U, epochSize)
{ }
#ifdef SWIG // for Python interop (adds indexer)
@ -3077,19 +3077,19 @@ namespace CNTK
{
public:
MomentumAsTimeConstantSchedule(double value)
: TrainingParameterSchedule<double>::TrainingParameterSchedule(value)
: TrainingParameterSchedule<double>::TrainingParameterSchedule(value, UnitType::Sample)
{
ConvertToPerSampleValues();
}
MomentumAsTimeConstantSchedule(const std::vector<double>& schedule, size_t epochSize = 1)
: TrainingParameterSchedule<double>::TrainingParameterSchedule(schedule, epochSize)
: TrainingParameterSchedule<double>::TrainingParameterSchedule(schedule, UnitType::Sample, epochSize)
{
ConvertToPerSampleValues();
}
MomentumAsTimeConstantSchedule(const std::vector<std::pair<size_t, double>>& schedule, size_t epochSize = 1)
: TrainingParameterSchedule<double>::TrainingParameterSchedule(schedule, epochSize)
: TrainingParameterSchedule<double>::TrainingParameterSchedule(schedule, UnitType::Sample, epochSize)
{
ConvertToPerSampleValues();
}
@ -3112,7 +3112,11 @@ namespace CNTK
{
double l1RegularizationWeight = 0.0;
double l2RegularizationWeight = 0.0;
TrainingParameterSchedule<double> gaussianNoiseInjectionStdDev = 0.0;
#ifdef SWIG //for python interop (swig does not fully support "using")
TrainingParameterPerUnitSchedule<double, TrainingParameterSchedule<double>::UnitType::Minibatch> gaussianNoiseInjectionStdDev = 0.0;
#else
TrainingParameterPerMinibatchSchedule<double> gaussianNoiseInjectionStdDev = 0.0;
#endif
double gradientClippingThresholdPerSample = std::numeric_limits<double>::infinity();
bool gradientClippingWithTruncation = true;
};
@ -3171,18 +3175,11 @@ namespace CNTK
virtual void ResetSmoothedGradients() = 0;
///
/// Returns current (per-sample) learning rate.
/// Returns current learning rate.
///
virtual double LearningRate(size_t minibatchSize = 1) const
virtual double LearningRate() const
{
auto learningRate = GetCurrentTrainingParameterValue<double>(m_learningRateSchedule);
if (m_learningRateSchedule.Unit() == LearningRateSchedule::UnitType::Minibatch)
{
// learning rate needs to be converted to the per-sample value.
return (minibatchSize == 0) ? 0.0 : learningRate / minibatchSize;
}
return learningRate;
return GetCurrentTrainingParameterValue<double>(m_learningRateSchedule);
}
protected:
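On the Python side this means Learner.learning_rate() no longer takes a minibatch size and simply returns the current schedule value; reset_learning_rate() likewise expects a schedule object. A small sketch using a one-weight stand-in model (names and init value are illustrative):

from cntk import parameter, input_variable
from cntk.learner import sgd, learning_rate_schedule, UnitType

w = parameter(shape=(1,), init=1)        # stand-in parameter, just enough to build a learner
res = input_variable(shape=(1,)) * w
learner = sgd(res.parameters, lr=learning_rate_schedule([1, 2, 3], UnitType.minibatch))
print(learner.learning_rate())           # 1.0, the raw per-minibatch value, no rescaling

learner.reset_learning_rate(learning_rate_schedule(0.0004, UnitType.sample))
print(learner.learning_rate())           # 0.0004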

View file

@ -41,6 +41,19 @@ namespace CNTK
std::string LearnerType() const;
// Returns current (per-sample) learning rate.
double LearningRate(size_t minibatchSize) const
{
auto learningRate = Learner::LearningRate();
if (m_learningRateSchedule.Unit() == LearningRateSchedule::UnitType::Minibatch)
{
// learning rate needs to be converted to the per-sample value.
return (minibatchSize == 0) ? 0.0 : learningRate / minibatchSize;
}
return learningRate;
}
AdditionalLearningOptions m_additionalOptions;
std::unordered_map<Parameter, NDArrayViewPtr> m_smoothedGradientValues;

View file

@ -244,7 +244,7 @@ namespace CNTK
}
template <typename T>
TrainingParameterSchedule<T>::TrainingParameterSchedule(const vector<T>& schedule, size_t epochSize, UnitType unit)
TrainingParameterSchedule<T>::TrainingParameterSchedule(const vector<T>& schedule, UnitType unit, size_t epochSize)
: m_unit(unit), m_epochSize(epochSize)
{
std::vector<std::pair<size_t, T>> s(schedule.size());
@ -257,7 +257,7 @@ namespace CNTK
}
template <typename T>
TrainingParameterSchedule<T>::TrainingParameterSchedule(const vector<std::pair<size_t, T>>& schedule, size_t epochSize, UnitType unit)
TrainingParameterSchedule<T>::TrainingParameterSchedule(const vector<std::pair<size_t, T>>& schedule, UnitType unit, size_t epochSize)
: m_unit(unit), m_epochSize(epochSize)
{
ConstructSchedule(schedule);

View file

@ -50,7 +50,7 @@ void TrainSimpleDistributedFeedForwardClassifer(const DeviceDescriptor& device,
auto trainingLoss = CNTK::CrossEntropyWithSoftmax(classifierOutput, labels, L"lossFunction");;
auto prediction = CNTK::ClassificationError(classifierOutput, labels, L"classificationError");
double learningRatePerSample = 0.02;
auto learningRatePerSample = LearningRatePerSampleSchedule(0.02);
minibatchSource = TextFormatMinibatchSource(L"SimpleDataTrain_cntk_text.txt", { { L"features", inputDim }, { L"labels", numOutputClasses } });
Trainer trainer(classifierOutput, trainingLoss, prediction, { SGDLearner(classifierOutput->Parameters(), learningRatePerSample) }, distributedTrainer);
size_t outputFrequencyInMinibatches = 20;

View file

@ -157,7 +157,7 @@ void TrainResNetCifarClassifer(const DeviceDescriptor& device, bool testSaveAndR
classifierOutput = classifierOutputVar;
}
double learningRatePerSample = 0.0078125;
LearningRatePerSampleSchedule learningRatePerSample = 0.0078125;
Trainer trainer(classifierOutput, trainingLoss, prediction, { SGDLearner(classifierOutput->Parameters(), learningRatePerSample) });
const size_t minibatchSize = 32;

View file

@ -51,7 +51,7 @@ void TestSGDLearner(size_t numParameters, size_t numMinibatches, const DeviceDes
{
NDShape shape = CreateShape(rng() % maxNumAxes + 1, maxDimSize);
auto parameters = CreateParameters<ElementType>(shape, numParameters, device);
auto learner = SGDLearner(parameters, 0.4);
auto learner = SGDLearner(parameters, LearningRatePerSampleSchedule(0.4));
TestUpdate<ElementType>(learner, shape, numMinibatches, device);
}
@ -61,10 +61,10 @@ void TestMomentumSGDLearner(size_t numParameters, size_t numMinibatches, const D
NDShape shape = CreateShape(rng() % maxNumAxes + 1, maxDimSize);
auto parameters = CreateParameters<ElementType>(shape, numParameters, device);
LearningRatePerMinibatchSchedule learnigRateSchedule = { { 3.0, 2.0, 1.0 }, numMinibatches };
MomentumSchedule momentumValues = { { { 1, 1.0 }, { 3, 0.1 }, { 10, 0.01 } }, 2 };
MomentumPerSampleSchedule momentumValues = { { { 1, 1.0 }, { 3, 0.1 }, { 10, 0.01 } }, 2 };
auto learner = MomentumSGDLearner(parameters, learnigRateSchedule, momentumValues);
TestUpdate<ElementType>(learner, shape, numMinibatches, device);
FloatingPointCompare(learner->LearningRate(100), 0.02, "Learner::LearningRate does not match expectation");
FloatingPointCompare(learner->LearningRate(), 2.0, "Learner::LearningRate does not match expectation");
}
template <typename ElementType>
@ -91,7 +91,7 @@ void TestFSAdaGradLearner(size_t numParameters, size_t numMinibatches, const Dev
{
NDShape shape = CreateShape(rng() % maxNumAxes + 1, maxDimSize);
auto parameters = CreateParameters<ElementType>(shape, numParameters, device);
auto learner = AdamLearner(parameters, { { 0.5 } }, MomentumAsTimeConstantSchedule({ 10, 100, 1000 }));
auto learner = AdamLearner(parameters, LearningRatePerSampleSchedule({ 0.5 }), MomentumAsTimeConstantSchedule({ 10.0, 100.0, 1000.0 }));
TestUpdate<ElementType>(learner, shape, numMinibatches, device);
}
@ -100,7 +100,7 @@ void TestRMSPropLearner(size_t numParameters, size_t numMinibatches, const Devic
{
NDShape shape = CreateShape(rng() % maxNumAxes + 1, maxDimSize);
auto parameters = CreateParameters<ElementType>(shape, numParameters, device);
auto learner = RMSPropLearner(parameters, { { { 3, 0.7 }, { 1, 0.2 } } }, 0.01, 0.02, 0.03, 0.1, 0.001);
auto learner = RMSPropLearner(parameters, LearningRatePerMinibatchSchedule({ { 3, 0.7 }, { 1, 0.2 } }), 0.01, 0.02, 0.03, 0.1, 0.001);
TestUpdate<ElementType>(learner, shape, numMinibatches, device);
}
@ -110,19 +110,19 @@ void TestTrainingParametersSchedule()
LearningRatePerMinibatchSchedule({ 3.0, 2.0, 1.0 }, LearningRateSchedule::EntireSweep);
}, "Was able to create not-yet-implemented sweep-based schedule.");
LearningRateSchedule schedule1 = 0.5;
LearningRatePerSampleSchedule schedule1 = 0.5;
assert(schedule1.Unit() == LearningRateSchedule::UnitType::Sample);
assert(schedule1[0] == 0.5);
assert(schedule1[1] == 0.5);
assert(schedule1[100] == 0.5);
LearningRateSchedule schedule2 = { 0.5 };
LearningRatePerSampleSchedule schedule2 = { 0.5 };
assert(schedule2.Unit() == LearningRateSchedule::UnitType::Sample);
assert(schedule2[0] == 0.5);
assert(schedule2[10] == 0.5);
assert(schedule2[100] == 0.5);
LearningRateSchedule schedule3 = { { 0.5, 0.3, 0.3 } };
LearningRatePerSampleSchedule schedule3({ 0.5, 0.3, 0.3 });
assert(schedule3.Unit() == LearningRateSchedule::UnitType::Sample);
assert(schedule3[0] == 0.5);
assert(schedule3[1] == 0.3);
@ -143,8 +143,8 @@ void TestTrainingParametersSchedule()
assert(schedule5[20] == 0.2);
assert(schedule5[100] == 0.2);
MomentumSchedule schedule6 = { { make_pair(1, 0.5) } }; // without make_pair this is interpreted as a vector of doubles
assert(schedule6.Unit() == MomentumSchedule::UnitType::Sample);
MomentumPerMinibatchSchedule schedule6 = { { make_pair(1, 0.5) } }; // without make_pair this is interpreted as a vector of doubles
assert(schedule6.Unit() == MomentumSchedule::UnitType::Minibatch);
assert(schedule6[0] == 0.5);
assert(schedule6[10] == 0.5);
assert(schedule6[100] == 0.5);
@ -165,7 +165,7 @@ void TestTrainingParametersSchedule()
assert(schedule8[20] == 0.2);
assert(schedule8[100] == 0.2);
LearningRateSchedule schedule9 = { { { 3, 0.5 }, { 2, 0.3 }, { 1, 0.2 } } };
LearningRatePerSampleSchedule schedule9 = { { { 3, 0.5 }, { 2, 0.3 }, { 1, 0.2 } } };
assert(schedule9.Unit() == LearningRateSchedule::UnitType::Sample);
assert(schedule9[0] == 0.5);
assert(schedule9[2] == 0.5);
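The Python schedule factories mirror these semantics: the unit is explicit, and indexing a schedule returns the value in effect after that many samples. A small sketch assuming the Python API introduced later in this commit:

from cntk.learner import learning_rate_schedule, training_parameter_schedule, UnitType

s = learning_rate_schedule([0.5, 0.3, 0.2], UnitType.sample)   # epoch_size defaults to 1
print(s[0], s[1], s[2], s[100])    # 0.5 0.3 0.2 0.2
m = training_parameter_schedule(0.99, UnitType.minibatch)
print(m[0], m[100])                # 0.99 0.99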

View file

@ -175,13 +175,12 @@ void TrainSequenceToSequenceTranslator(const DeviceDescriptor& device, bool useS
auto rawInputStreamInfo = minibatchSource->StreamInfo(featureStreamName);
auto rawLabelsStreamInfo = minibatchSource->StreamInfo(labelStreamName);
double learningRatePerSample = 0.007;
size_t momentumTimeConstant = 1100;
double momentumPerSample = std::exp(-1.0 / momentumTimeConstant);
LearningRatePerSampleSchedule learningRatePerSample = 0.007;
MomentumAsTimeConstantSchedule momentumTimeConstant = 1100;
AdditionalLearningOptions additionalOptions;
additionalOptions.gradientClippingThresholdPerSample = 2.3;
additionalOptions.gradientClippingWithTruncation = true;
Trainer trainer(z, ce, errs, { MomentumSGDLearner(z->Parameters(), learningRatePerSample, momentumPerSample, additionalOptions) });
Trainer trainer(z, ce, errs, { MomentumSGDLearner(z->Parameters(), learningRatePerSample, momentumTimeConstant, additionalOptions) });
size_t outputFrequencyInMinibatches = 1;
size_t minibatchSize1 = 72;

View file

@ -46,10 +46,9 @@ void TrainLSTMSequenceClassifer(const DeviceDescriptor& device, bool useSparseLa
auto featureStreamInfo = minibatchSource->StreamInfo(featuresName);
auto labelStreamInfo = minibatchSource->StreamInfo(labelsName);
double learningRatePerSample = 0.0005;
size_t momentumTimeConstant = 256;
double momentumPerSample = std::exp(-1.0 / momentumTimeConstant);
Trainer trainer(classifierOutput, trainingLoss, prediction, { MomentumSGDLearner(classifierOutput->Parameters(), learningRatePerSample, momentumPerSample) });
LearningRatePerSampleSchedule learningRatePerSample = 0.0005;
MomentumAsTimeConstantSchedule momentumTimeConstant = 256;
Trainer trainer(classifierOutput, trainingLoss, prediction, { MomentumSGDLearner(classifierOutput->Parameters(), learningRatePerSample, momentumTimeConstant) });
size_t outputFrequencyInMinibatches = 1;
for (size_t i = 0; true; i++)
@ -91,68 +90,68 @@ void TestLearningRateControl(const DeviceDescriptor& device)
LearningRatePerSampleSchedule learningRateSchedule({ { 2, 0.0005 }, { 2, 0.00025 } }, actualMBSize);
auto learner = SGDLearner(classifierOutput->Parameters(), learningRateSchedule);
Trainer trainer(classifierOutput, trainingLoss, prediction, { learner });
FloatingPointCompare(learner->LearningRate(minibatchSize), 0.0005, "Learner::LearningRate does not match expectation");
FloatingPointCompare(learner->LearningRate(), 0.0005, "Learner::LearningRate does not match expectation");
trainer.TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
FloatingPointCompare(learner->LearningRate(minibatchSize), 0.0005, "Learner::LearningRate does not match expectation");
FloatingPointCompare(learner->LearningRate(), 0.0005, "Learner::LearningRate does not match expectation");
const wchar_t* modelFile = L"seq2seq.model";
trainer.SaveCheckpoint(modelFile);
trainer.TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
auto MB2Loss = trainer.PreviousMinibatchLossAverage();
FloatingPointCompare(learner->LearningRate(minibatchSize), 0.00025, "Learner::LearningRate does not match expectation");
FloatingPointCompare(learner->LearningRate(), 0.00025, "Learner::LearningRate does not match expectation");
trainer.TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
auto MB3Loss = trainer.PreviousMinibatchLossAverage();
FloatingPointCompare(learner->LearningRate(minibatchSize), 0.00025, "Learner::LearningRate does not match expectation");
FloatingPointCompare(learner->LearningRate(), 0.00025, "Learner::LearningRate does not match expectation");
trainer.RestoreFromCheckpoint(modelFile);
FloatingPointCompare(learner->LearningRate(minibatchSize), 0.0005, "Learner::LearningRate does not match expectation");
FloatingPointCompare(learner->LearningRate(), 0.0005, "Learner::LearningRate does not match expectation");
trainer.TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
auto postRestoreMB2Loss = trainer.PreviousMinibatchLossAverage();
FloatingPointCompare(postRestoreMB2Loss, MB2Loss, "Post checkpoint restoration training loss does not match expectation");
FloatingPointCompare(learner->LearningRate(minibatchSize), 0.00025, "Learner::LearningRate does not match expectation");
FloatingPointCompare(learner->LearningRate(), 0.00025, "Learner::LearningRate does not match expectation");
trainer.TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
auto postRestoreMB3Loss = trainer.PreviousMinibatchLossAverage();
FloatingPointCompare(postRestoreMB3Loss, MB3Loss, "Post checkpoint restoration training loss does not match expectation");
trainer.RestoreFromCheckpoint(modelFile);
FloatingPointCompare(learner->LearningRate(minibatchSize), 0.0005, "Learner::LearningRate does not match expectation");
FloatingPointCompare(learner->LearningRate(), 0.0005, "Learner::LearningRate does not match expectation");
learner->ResetLearningRate(0.0004);
FloatingPointCompare(learner->LearningRate(minibatchSize), 0.0004, "Learner::LearningRate does not match expectation");
learner->ResetLearningRate(LearningRatePerSampleSchedule(0.0004));
FloatingPointCompare(learner->LearningRate(), 0.0004, "Learner::LearningRate does not match expectation");
trainer.SaveCheckpoint(modelFile);
trainer.TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
postRestoreMB2Loss = trainer.PreviousMinibatchLossAverage();
FloatingPointCompare(postRestoreMB2Loss, MB2Loss, "Post checkpoint restoration training loss does not match expectation");
FloatingPointCompare(learner->LearningRate(minibatchSize), 0.0004, "Learner::LearningRate does not match expectation");
FloatingPointCompare(learner->LearningRate(), 0.0004, "Learner::LearningRate does not match expectation");
trainer.TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
postRestoreMB3Loss = trainer.PreviousMinibatchLossAverage();
FloatingPointCompare(postRestoreMB3Loss, MB3Loss, "Post checkpoint restoration training loss does not match expectation");
FloatingPointCompare(learner->LearningRate(minibatchSize), 0.0004, "Learner::LearningRate does not match expectation");
FloatingPointCompare(learner->LearningRate(), 0.0004, "Learner::LearningRate does not match expectation");
trainer.RestoreFromCheckpoint(modelFile);
FloatingPointCompare(learner->LearningRate(minibatchSize), 0.0004, "Learner::LearningRate does not match expectation");
FloatingPointCompare(learner->LearningRate(), 0.0004, "Learner::LearningRate does not match expectation");
trainer.TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
postRestoreMB2Loss = trainer.PreviousMinibatchLossAverage();
FloatingPointCompare(postRestoreMB2Loss, MB2Loss, "Post checkpoint restoration training loss does not match expectation");
FloatingPointCompare(learner->LearningRate(minibatchSize), 0.0004, "Learner::LearningRate does not match expectation");
FloatingPointCompare(learner->LearningRate(), 0.0004, "Learner::LearningRate does not match expectation");
trainer.TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
postRestoreMB3Loss = trainer.PreviousMinibatchLossAverage();
FloatingPointCompare(postRestoreMB3Loss, MB3Loss, "Post checkpoint restoration training loss does not match expectation");
FloatingPointCompare(learner->LearningRate(minibatchSize), 0.0004, "Learner::LearningRate does not match expectation");
FloatingPointCompare(learner->LearningRate(), 0.0004, "Learner::LearningRate does not match expectation");
}
void TrainLSTMSequenceClassifer()

View file

@ -203,7 +203,7 @@ void TestLearnerSerialization(int numParameters, const DeviceDescriptor& device)
gradientValues[parameter] = NDArrayView::RandomUniform<ElementType>(shape, -0.5, 0.5, numParameters + i, device);
}
auto learner1 = SGDLearner(parameters, 0.05);
auto learner1 = SGDLearner(parameters, LearningRatePerSampleSchedule(0.05));
learner1->Update(gradientValues, 1);
@ -215,7 +215,7 @@ void TestLearnerSerialization(int numParameters, const DeviceDescriptor& device)
stream.flush();
}
auto learner2 = SGDLearner(parameters, 0.05);
auto learner2 = SGDLearner(parameters, LearningRatePerSampleSchedule( 0.05));
{
Dictionary checkpoint;
@ -405,7 +405,9 @@ void TestFunctionSerialization(const DeviceDescriptor& device)
TestFunctionSaveAndLoad(BuildLSTMClassifierNet(inputVar, 5, device), device);
}
Trainer BuildTrainer(const FunctionPtr& function, const Variable& labels, LearningRateSchedule lr = 0.005, MomentumSchedule m = 0.0)
Trainer BuildTrainer(const FunctionPtr& function, const Variable& labels,
LearningRateSchedule lr = LearningRatePerSampleSchedule(0.005),
MomentumSchedule m = MomentumAsTimeConstantSchedule(0.0))
{
auto trainingLoss = CNTK::CrossEntropyWithSoftmax(function, labels, L"lossFunction");
auto prediction = CNTK::ClassificationError(function, labels, L"classificationError");
@ -498,7 +500,7 @@ void TestTrainingWithCheckpointing(const FunctionPtr& function1, const FunctionP
auto minibatchData = minibatchSource->GetNextMinibatch(minibatchSize, device);
auto actualMBSize = minibatchData[labelStreamInfo].m_numSamples;
LearningRateSchedule learningRateSchedule({ { 2, 0.005 }, { 2, 0.0025 }, { 2, 0.0005 }, { 2, 0.00025 } }, actualMBSize);
LearningRatePerSampleSchedule learningRateSchedule({ { 2, 0.005 }, { 2, 0.0025 }, { 2, 0.0005 }, { 2, 0.00025 } }, actualMBSize);
MomentumAsTimeConstantSchedule momentumValues({ { 2, 100 }, { 2, 200 }, { 2, 400 }, { 2, 800 } }, actualMBSize);
@ -633,7 +635,7 @@ void TestLegacyModelSaving(const DeviceDescriptor& device)
auto minibatchData = minibatchSource->GetNextMinibatch(minibatchSize, device);
auto actualMBSize = minibatchData[labelStreamInfo].m_numSamples;
LearningRateSchedule learningRateSchedule({ { 2, 0.0005 }, { 2, 0.00025 } }, actualMBSize);
LearningRatePerSampleSchedule learningRateSchedule({ { 2, 0.0005 }, { 2, 0.00025 } }, actualMBSize);
auto learner = SGDLearner(classifierOutput->Parameters(), learningRateSchedule);
Trainer trainer(classifierOutput, trainingLoss, prediction, { learner });

View file

@ -59,7 +59,7 @@ void TrainSimpleFeedForwardClassifer(const DeviceDescriptor& device)
prediction = predictionVar;
}
double learningRatePerSample = 0.02;
LearningRatePerSampleSchedule learningRatePerSample = 0.02;
minibatchSource = TextFormatMinibatchSource(L"SimpleDataTrain_cntk_text.txt", { { L"features", inputDim }, { L"labels", numOutputClasses } });
Trainer trainer(classifierOutput, trainingLoss, prediction, { SGDLearner(classifierOutput->Parameters(), learningRatePerSample) });
size_t outputFrequencyInMinibatches = 20;
@ -121,7 +121,7 @@ void TrainMNISTClassifier(const DeviceDescriptor& device)
auto featureStreamInfo = minibatchSource->StreamInfo(featureStreamName);
auto labelStreamInfo = minibatchSource->StreamInfo(labelsStreamName);
double learningRatePerSample = 0.003125;
LearningRatePerSampleSchedule learningRatePerSample = 0.003125;
Trainer trainer(classifierOutput, trainingLoss, prediction, { SGDLearner(classifierOutput->Parameters(), learningRatePerSample) });
size_t outputFrequencyInMinibatches = 20;

View file

@ -109,10 +109,9 @@ void TrainTruncatedLSTMAcousticModelClassifer(const DeviceDescriptor& device, bo
featureStreamInfo = minibatchSource->StreamInfo(features);
auto labelStreamInfo = minibatchSource->StreamInfo(labels);
double learningRatePerSample = 0.000781;
size_t momentumTimeConstant = 6074;
double momentumPerSample = std::exp(-1.0 / momentumTimeConstant);
auto learner = MomentumSGDLearner(classifierOutput->Parameters(), learningRatePerSample, momentumPerSample);
LearningRatePerSampleSchedule learningRatePerSample = 0.000781;
MomentumAsTimeConstantSchedule momentumTimeConstant = 6074;
auto learner = MomentumSGDLearner(classifierOutput->Parameters(), learningRatePerSample, momentumTimeConstant);
Trainer trainer(classifierOutput, trainingLoss, prediction, {learner});
size_t outputFrequencyInMinibatches = 1;

View file

@ -9,7 +9,7 @@ import sys
import os
from cntk.device import cpu, set_default_device
from cntk import Trainer
from cntk.learner import sgd
from cntk.learner import sgd, learning_rate_schedule, UnitType
from cntk.ops import input_variable, cross_entropy_with_softmax, classification_error, sigmoid
from cntk.utils import ProgressPrinter
@ -52,8 +52,9 @@ def ffnet():
ce = cross_entropy_with_softmax(netout, label)
pe = classification_error(netout, label)
lr_per_minibatch=learning_rate_schedule(0.5, UnitType.minibatch)
# Instantiate the trainer object to drive the model training
trainer = Trainer(netout, ce, pe, sgd(netout.parameters, lr=0.02))
trainer = Trainer(netout, ce, pe, sgd(netout.parameters, lr=lr_per_minibatch))
# Get minibatches of training data and perform model training
minibatch_size = 25
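With the validation helpers added to cntk.learner (next file), passing a bare float now fails fast. A sketch assuming netout is the network built above:

from cntk.learner import sgd, learning_rate_schedule, UnitType

try:
    sgd(netout.parameters, lr=0.02)          # old style: no longer accepted
except ValueError as e:
    print(e)                                 # learning_rate must be a training schedule ...

learner = sgd(netout.parameters, lr=learning_rate_schedule(0.5, UnitType.minibatch))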

View file

@ -50,6 +50,33 @@ the following learning algorithms:
+------------------------+
'''
# an internal method to verify that the learning rate schedule
# has a proper (per-sample or per-MB schedule) type and raise
# an exception otherwise
def _verify_learning_rate_type(learning_rate):
if not isinstance(learning_rate,
(cntk_py.training_parameter_per_sample_schedule,
cntk_py.training_parameter_per_minibatch_schedule)):
raise ValueError('learning_rate type (%s) not supported. '
'learning_rate must be a training schedule '
'(output of learning_rate_schedule() function)'
% type(learning_rate))
# an internal method to verify that the mometum schedule
# has a proper (per-MB or time-constant schedule) type and raise
# an exception otherwise
def _verify_momentum_type(momentum):
if not isinstance(momentum,
(cntk_py.training_parameter_per_minibatch_schedule,
cntk_py.momentum_as_time_constant_schedule)):
raise ValueError('momentum type (%s) not supported. '
'momentum must be a training schedule '
'(output of momentum_schedule() or '
'momentum_as_time_constant_schedule() function)'
% type(momentum))
class Learner(cntk_py.Learner):
'''
Abstraction for learning a subset of parameters of a learnable function using first order gradient values
@ -63,7 +90,7 @@ class Learner(cntk_py.Learner):
Update the parameters associated with this learner.
Args:
gradient_values (`dict`): maps :class:`~cntk.variables.Parameter` to
gradient_values (dict): maps :class:`~cntk.variables.Parameter` to
a NumPy array containing the first order gradient values for the
Parameter w.r.t. the training objective.
training_sample_count (int): training sample count
@ -85,49 +112,43 @@ class Learner(cntk_py.Learner):
'''
return super(Learner, self).parameters()
def reset_learning_rate(self, learning_rate):
'''
Resets the learning rate.
Args:
learning_rate (float, list or a training schedule): learning rate
to reset to
learning_rate (output of :func:`learning_rate_schedule`)
learning rate to reset to
'''
learning_rate = learning_rate_schedule(learning_rate)
_verify_learning_rate_type(learning_rate)
return super(Learner, self).reset_learning_rate(learning_rate)
def learning_rate(self, minibatch_size=1):
def learning_rate(self):
'''
The learning rate.
Args:
minibatch_size (int): minibatch size to re-scaled
the learning rate to the per-sample value (in case when the schedule
was build with ``unit=UnitType.minibatch``).
Current learning rate.
'''
return super(Learner, self).learning_rate(minibatch_size)
return super(Learner, self).learning_rate()
@typemap
def training_parameter_schedule(schedule, epoch_size=1, unit=UnitType.sample):
def training_parameter_schedule(schedule, unit, epoch_size=1):
'''
Create a training parameter schedule containing either per-sample (default)
or per-minibatch values.
Examples:
>>> # Use a fixed value 0.01 for all samples
>>> s = training_parameter_schedule(0.01)
>>> s = training_parameter_schedule(0.01, UnitType.sample)
>>> s[0], s[1]
(0.01, 0.01)
>>> # Use 0.01 for the first 1000 samples, then 0.001 for the remaining ones
>>> s = training_parameter_schedule([0.01, 0.001], 1000)
>>> s = training_parameter_schedule([0.01, 0.001], UnitType.sample, 1000)
>>> s[0], s[1], s[1000], s[1001]
(0.01, 0.01, 0.001, 0.001)
>>> # Use 0.1 for the first 12 epochs, then 0.01 for the next 15,
>>> # followed by 0.001 for the remaining ones, with a 100 samples in an epoch
>>> s = training_parameter_schedule([(12, 0.1), (15, 0.01), (1, 0.001)], 100)
>>> s = training_parameter_schedule([(12, 0.1), (15, 0.01), (1, 0.001)], UnitType.sample, 100)
>>> s[0], s[1199], s[1200], s[2699], s[2700], s[5000]
(0.1, 0.1, 0.01, 0.01, 0.001, 0.001)
@ -136,12 +157,11 @@ def training_parameter_schedule(schedule, epoch_size=1, unit=UnitType.sample):
for all samples. In case of list, the elements are used as the
values for ``epoch_size`` samples. If list contains pair, the second element is
used as a value for (``epoch_size`` x first element) samples
unit (:class:`UnitType`): one of two
* ``sample``: the returned schedule contains per-sample values
* ``minibatch``: the returned schedule contains per-minibatch values.
epoch_size (int): number of samples as a scheduling unit. Parameters in
the schedule change their values every ``epoch_size`` samples.
unit (:class:`UnitType`): one of two
* ``sample``: the returned schedule contains per-sample values (default)
* ``minibatch``: the returned schedule contains per-minibatch values.
Returns:
training parameter schedule
@ -153,9 +173,11 @@ def training_parameter_schedule(schedule, epoch_size=1, unit=UnitType.sample):
raise ValueError('schedule unit "%s" is not supported' %
str(method))
if isinstance(schedule, (cntk_py.training_parameter_per_sample_schedule,
cntk_py.training_parameter_per_minibatch_schedule,
cntk_py.momentum_as_time_constant_schedule)):
if unit == UnitType.sample:
if isinstance(schedule, cntk_py.training_parameter_per_sample_schedule):
return schedule
else:
if isinstance(schedule, cntk_py.training_parameter_per_minibatch_schedule):
return schedule
if isinstance(schedule, (int, float)):
@ -173,7 +195,7 @@ def training_parameter_schedule(schedule, epoch_size=1, unit=UnitType.sample):
raise ValueError('schedule must be either a float or a list, not %s'%type(schedule))
@typemap
def learning_rate_schedule(lr, epoch_size=1, unit=UnitType.sample):
def learning_rate_schedule(lr, unit, epoch_size=1):
'''
Create a learning rate schedule (using the same semantics as
:func:`training_parameter_schedule`).
@ -181,10 +203,10 @@ def learning_rate_schedule(lr, epoch_size=1, unit=UnitType.sample):
Args:
lr (float or list): see parameter ``schedule`` in
:func:`training_parameter_schedule`.
epoch_size (int): see parameter ``epoch_size`` in
:func:`training_parameter_schedule`.
unit (:class:`UnitType`): see parameter
``unit`` in :func:`training_parameter_schedule`.
epoch_size (int): see parameter ``epoch_size`` in
:func:`training_parameter_schedule`.
Returns:
learning rate schedule
@ -192,23 +214,21 @@ def learning_rate_schedule(lr, epoch_size=1, unit=UnitType.sample):
See also:
:func:`training_parameter_schedule`
'''
return training_parameter_schedule(lr, epoch_size, unit)
return training_parameter_schedule(lr, unit, epoch_size)
@typemap
def momentum_schedule(momentum, epoch_size=1, unit=UnitType.sample):
def momentum_schedule(momentum, epoch_size=1):
'''
Create a momentum schedule (using the same semantics as
:func:`training_parameter_schedule`).
Create a per-minibatch momentum schedule (using the same semantics as
:func:`training_parameter_schedule` with the `unit=UnitType.minibatch`).
Args:
momentum (float or list): see parameter ``schedule`` in
:func:`training_parameter_schedule`.
epoch_size (int): see parameter ``epoch_size`` in
:func:`training_parameter_schedule`.
unit (:class:`UnitType`): see parameter
``unit`` in :func:`training_parameter_schedule`.
If you want to provide momentum values in a sample/minibatch
If you want to provide momentum values in a minibatch-size
agnostic way, use :func:`momentum_as_time_constant_schedule`.
Examples:
@ -228,32 +248,23 @@ def momentum_schedule(momentum, epoch_size=1, unit=UnitType.sample):
>>> m[0], m[998], m[999], m[999+888-1], m[999+888]
(0.99, 0.99, 0.88, 0.88, 0.77)
Args:
momentum (float or list): see parameter ``schedule`` in
:func:`training_parameter_schedule`.
epoch_size (int): see parameter ``epoch_size`` in
:func:`training_parameter_schedule`.
unit (:class:`UnitType`): see parameter
``unit`` in :func:`training_parameter_schedule`.
Returns:
momentum schedule
'''
return training_parameter_schedule(momentum, epoch_size, unit)
return training_parameter_schedule(momentum, UnitType.minibatch, epoch_size)
@typemap
def momentum_as_time_constant_schedule(momentum, epoch_size=1):
'''
Create a momentum schedule in a minibatch agnostic way (using the same
semantics as :func:`training_parameter_schedule`).
Create a momentum schedule in a minibatch-size agnostic way
(using the same semantics as :func:`training_parameter_schedule`
with `unit=UnitType.sample`).
Args:
momentum (float or list): see parameter ``schedule`` in
:func:`training_parameter_schedule`.
epoch_size (int): see parameter ``epoch_size`` in
:func:`training_parameter_schedule`.
unit (:class:`UnitType`): see parameter
``unit`` in :func:`training_parameter_schedule`.
CNTK specifies momentum in a minibatch-size agnostic way as the time
constant (in samples) of a unit-gain 1st-order IIR filter. The value
@ -263,7 +274,6 @@ def momentum_as_time_constant_schedule(momentum, epoch_size=1):
If you want to specify the momentum per sample (or per minibatch),
use :func:`momentum_schedule`.
Examples:
>>> # Use a fixed momentum of 1100 for all samples
>>> m = momentum_as_time_constant_schedule(1100)
@ -272,18 +282,10 @@ def momentum_as_time_constant_schedule(momentum, epoch_size=1):
>>> # then 1500 for the remaining ones
>>> m = momentum_as_time_constant_schedule([1100, 1500], 1000)
Args:
momentum (float or list): see parameter ``schedule`` in
:func:`training_parameter_schedule`.
epoch_size (int): see parameter ``epoch_size`` in
:func:`training_parameter_schedule`.
Returns:
momentum as time constant schedule
'''
if isinstance(momentum, (cntk_py.training_parameter_per_sample_schedule,
cntk_py.training_parameter_per_minibatch_schedule,
cntk_py.momentum_as_time_constant_schedule)):
if isinstance(momentum, (cntk_py.momentum_as_time_constant_schedule)):
return momentum
if isinstance(momentum, (int, float)):
@ -293,7 +295,6 @@ def momentum_as_time_constant_schedule(momentum, epoch_size=1):
raise ValueError('momentum must be either a float or a list, not %s'%type(momentum))
# TODO figure out how to pass infty to C++ in a portable way
@typemap
def sgd(parameters, lr,
@ -308,9 +309,7 @@ def sgd(parameters, lr,
parameters (list of parameters): list of network parameters to tune.
These can be obtained by the '.parameters()' method of the root
operator.
lr (float, list or output of :func:`learning_rate_schedule`): learning rate
schedule. When the argument value is a float or a list, lr is
converted to a per-sample schedule by invoking :func:`learning_rate_schedule`.
lr (output of :func:`learning_rate_schedule`): learning rate schedule.
l1_regularization_weight (float, optional): the L1 regularization weight per sample,
defaults to 0.0
l2_regularization_weight (float, optional): the L2 regularization weight per sample,
@ -319,7 +318,8 @@ def sgd(parameters, lr,
of the Gaussian noise added to parameters post update, defaults to 0.0
gradient_clipping_threshold_per_sample (float, optional): clipping threshold
per sample, defaults to infinity
gradient_clipping_with_truncation (bool, default ``True``): gradient clipping
gradient_clipping_with_truncation (bool, default ``True``): use gradient clipping
with truncation
Returns:
Instance of a :class:`~cntk.learner.Learner` that can be passed to the :class:`~cntk.trainer.Trainer`
@ -329,8 +329,9 @@ def sgd(parameters, lr,
<http://research.microsoft.com/pubs/192769/tricks-2012.pdf>`_. Neural
Networks: Tricks of the Trade: Springer, 2012.
'''
lr = learning_rate_schedule(lr)
gaussian_noise_injection_std_dev = training_parameter_schedule(gaussian_noise_injection_std_dev)
_verify_learning_rate_type(lr)
gaussian_noise_injection_std_dev = \
training_parameter_schedule(gaussian_noise_injection_std_dev, UnitType.minibatch)
additional_options = cntk_py.AdditionalLearningOptions()
additional_options.l1_regularization_weight = l1_regularization_weight
@ -347,17 +348,15 @@ def momentum_sgd(parameters, lr, momentum,
gaussian_noise_injection_std_dev=0.0, gradient_clipping_threshold_per_sample=1E10,
gradient_clipping_with_truncation=True):
'''
Creates a Momemtum SGD learner instance to learn the parameters.
Creates a Momentum SGD learner instance to learn the parameters.
Args:
parameters (list of parameters): list of network parameters to tune.
These can be obtained by the root operator's ``parameters``.
lr (float, list```` or output of :func:`learning_rate_schedule`): learning rate
schedule. When the argument value is a float or a list, lr is
converted to a per-sample schedule by invoking :func:`learning_rate_schedule`.
momentum (float, list or output of :func:`momentum_schedule` or :func:`momentum_as_time_constant_schedule`): momentum schedule. When the argument
value is a float or a list, momentum is converted to a per-sample schedule by
invoking :func:`momentum_schedule`. Refer to the `wiki
lr (output of :func:`learning_rate_schedule`): learning rate schedule.
momentum (output of :func:`momentum_schedule` or
:func:`momentum_as_time_constant_schedule`): momentum schedule.
For additional information, please refer to the `wiki
<https://github.com/Microsoft/CNTK/wiki/SGD-block#converting-learning-rate-and-momentum-parameters-from-other-toolkits>`_.
l1_regularization_weight (float, optional): the L1 regularization weight per sample,
defaults to 0.0
@ -367,14 +366,16 @@ def momentum_sgd(parameters, lr, momentum,
of the Gaussian noise added to parameters post update, defaults to 0.0
gradient_clipping_threshold_per_sample (float, optional): clipping threshold
per sample, defaults to infinity
gradient_clipping_with_truncation (bool, default ``True``): gradient clipping
gradient_clipping_with_truncation (bool, default ``True``): use gradient clipping
with truncation
Returns:
Instance of a :class:`~cntk.learner.Learner` that can be passed to the :class:`~cntk.trainer.Trainer`
Instance of a :class:`cntk.learner.Learner` that can be passed to the :class:`cntk.trainer.Trainer`
'''
lr = learning_rate_schedule(lr)
momentum = momentum_schedule(momentum)
gaussian_noise_injection_std_dev = training_parameter_schedule(gaussian_noise_injection_std_dev)
_verify_learning_rate_type(lr)
_verify_momentum_type(momentum)
gaussian_noise_injection_std_dev = \
training_parameter_schedule(gaussian_noise_injection_std_dev, UnitType.minibatch)
additional_options = cntk_py.AdditionalLearningOptions()
additional_options.l1_regularization_weight = l1_regularization_weight
@ -399,12 +400,10 @@ def nesterov(parameters, lr, momentum,
Args:
parameters (list of parameters): list of network parameters to tune.
These can be obtained by the root operator's ``parameters``.
lr (float, list or output of :func:`learning_rate_schedule`): learning rate
schedule. When the argument value is a float or a list, lr is
converted to a per-sample schedule by invoking :func:`learning_rate_schedule`.
momentum (float, list or output of :func:`momentum_schedule` or :func:`momentum_as_time_constant_schedule`): momentum schedule. When the argument
value is a float or a list, momentum is converted to a per-sample schedule by
invoking :func:`momentum_schedule`. Refer to the `wiki
lr (output of :func:`learning_rate_schedule`): learning rate schedule.
momentum (output of :func:`momentum_schedule` or
:func:`momentum_as_time_constant_schedule`): momentum schedule.
For additional information, please refer to the `wiki
<https://github.com/Microsoft/CNTK/wiki/SGD-block#converting-learning-rate-and-momentum-parameters-from-other-toolkits>`_.
l1_regularization_weight (float, optional): the L1 regularization weight per sample,
defaults to 0.0
@ -414,7 +413,8 @@ def nesterov(parameters, lr, momentum,
of the Gaussian noise added to parameters post update, defaults to 0.0
gradient_clipping_threshold_per_sample (float, optional): clipping threshold
per sample, defaults to infinity
gradient_clipping_with_truncation (bool, default ``True``): gradient clipping
gradient_clipping_with_truncation (bool, default ``True``): use gradient clipping
with truncation
Returns:
Instance of a :class:`~cntk.learner.Learner` that can be passed to the
@ -429,9 +429,10 @@ def nesterov(parameters, lr, momentum,
of the 30th International Conference on Machine Learning, 2013.
'''
lr = learning_rate_schedule(lr)
momentum = momentum_schedule(momentum)
gaussian_noise_injection_std_dev = training_parameter_schedule(gaussian_noise_injection_std_dev)
_verify_learning_rate_type(lr)
_verify_momentum_type(momentum)
gaussian_noise_injection_std_dev = \
training_parameter_schedule(gaussian_noise_injection_std_dev, UnitType.minibatch)
additional_options = cntk_py.AdditionalLearningOptions()
additional_options.l1_regularization_weight = l1_regularization_weight
@ -455,9 +456,7 @@ def adagrad(parameters, lr, need_ave_multiplier=True,
Args:
parameters (list of parameters): list of network parameters to tune.
These can be obtained by the root operator's ``parameters``.
lr (float, list or output of :func:`learning_rate_schedule`): learning rate
schedule. When the argument value is a float or a list, lr is
converted to a per-sample schedule by invoking :func:`learning_rate_schedule`.
lr (output of :func:`learning_rate_schedule`): learning rate schedule.
need_ave_multiplier (bool, default):
l1_regularization_weight (float, optional): the L1 regularization weight per sample,
defaults to 0.0
@ -467,7 +466,8 @@ def adagrad(parameters, lr, need_ave_multiplier=True,
of the Gaussian noise added to parameters post update, defaults to 0.0
gradient_clipping_threshold_per_sample (float, optional): clipping threshold
per sample, defaults to infinity
gradient_clipping_with_truncation (bool, default `True`): gradient clipping
gradient_clipping_with_truncation (bool, default ``True``): use gradient clipping
with truncation
Returns:
Instance of a :class:`~cntk.learner.Learner` that can be passed to the :class:`~cntk.trainer.Trainer`
@ -478,8 +478,9 @@ def adagrad(parameters, lr, need_ave_multiplier=True,
<http://www.magicbroom.info/Papers/DuchiHaSi10.pdf>`_. The Journal of
Machine Learning Research, 2011.
'''
lr = learning_rate_schedule(lr)
gaussian_noise_injection_std_dev = training_parameter_schedule(gaussian_noise_injection_std_dev)
_verify_learning_rate_type(lr)
gaussian_noise_injection_std_dev = \
training_parameter_schedule(gaussian_noise_injection_std_dev, UnitType.minibatch)
additional_options = cntk_py.AdditionalLearningOptions()
additional_options.l1_regularization_weight = l1_regularization_weight
@ -506,16 +507,14 @@ def adam_sgd(parameters, lr, momentum,
Args:
parameters (list of parameters): list of network parameters to tune.
These can be obtained by the root operator's ``parameters``.
lr (float, list or output of :func:`learning_rate_schedule`): learning rate
schedule. When the argument value is a float or a list, lr is
converted to a per-sample schedule by invoking :func:`learning_rate_schedule`.
momentum (float, list or output of :func:`momentum_schedule` or :func:`momentum_as_time_constant_schedule`): momentum schedule. When the argument
value is a float or a list, momentum is converted to a per-sample schedule by
invoking :func:`momentum_schedule`. Refer to the `wiki
lr (output of :func:`learning_rate_schedule`): learning rate schedule.
momentum (output of :func:`momentum_schedule` or
:func:`momentum_as_time_constant_schedule`): momentum schedule.
For additional information, please refer to the `wiki
<https://github.com/Microsoft/CNTK/wiki/SGD-block#converting-learning-rate-and-momentum-parameters-from-other-toolkits>`_.
variance_momentum (float, list or output of :func:`momentum_schedule` or :func:`momentum_as_time_constant_schedule`): variance momentum schedule. When the argument
value is a float or a list, variance momentum is converted to a per-sample schedule by
invoking :func:`momentum_schedule`. Defaults to momentum_as_time_constant_schedule(720000).
variance_momentum (output of :func:`momentum_schedule` or
:func:`momentum_as_time_constant_schedule`): variance momentum schedule. Defaults
to ``momentum_as_time_constant_schedule(720000)``.
l1_regularization_weight (float, optional): the L1 regularization weight per sample,
defaults to 0.0
l2_regularization_weight (float, optional): the L2 regularization weight per sample,
@ -524,7 +523,8 @@ def adam_sgd(parameters, lr, momentum,
of the Gaussian noise added to parameters post update, defaults to 0.0
gradient_clipping_threshold_per_sample (float, optional): clipping threshold
per sample, defaults to infinity
gradient_clipping_with_truncation (bool, default `True`): gradient clipping
gradient_clipping_with_truncation (bool, default ``True``): use gradient clipping
with truncation
Returns:
Instance of a :class:`~cntk.learner.Learner` that can be passed to the :class:`~cntk.trainer.Trainer`
@ -537,10 +537,11 @@ def adam_sgd(parameters, lr, momentum,
if not low_memory:
raise NotImplementedError('adam: low_memory=True currently required')
lr = learning_rate_schedule(lr)
momentum = momentum_schedule(momentum)
variance_momentum = momentum_schedule(variance_momentum)
gaussian_noise_injection_std_dev = training_parameter_schedule(gaussian_noise_injection_std_dev)
_verify_learning_rate_type(lr)
_verify_momentum_type(momentum)
_verify_momentum_type(variance_momentum)
gaussian_noise_injection_std_dev = \
training_parameter_schedule(gaussian_noise_injection_std_dev, UnitType.minibatch)
additional_options = cntk_py.AdditionalLearningOptions()
additional_options.l1_regularization_weight = l1_regularization_weight
@ -565,15 +566,13 @@ def rmsprop(parameters, lr,
Args:
parameters (list of parameters): list of network parameters to tune.
These can be obtained by the root operator's ``parameters``.
lr (float, list or output of :func:`learning_rate_schedule`): learning rate
schedule. When the argument value is a float or a list, lr is
converted to a per-sample schedule by invoking :func:`learning_rate_schedule`.
lr (output of :func:`learning_rate_schedule`): learning rate schedule.
gamma (float):
inc (float):
dec (float):
max (float):
min (float):
need_ave_multiplier (bool, default):
need_ave_multiplier (bool, default ``True``):
l1_regularization_weight (float, optional): the L1 regularization weight per sample,
defaults to 0.0
l2_regularization_weight (float, optional): the L2 regularization weight per sample,
@ -582,13 +581,15 @@ def rmsprop(parameters, lr,
of the Gaussian noise added to parameters post update, defaults to 0.0
gradient_clipping_threshold_per_sample (float, optional): clipping threshold
per sample, defaults to infinity
gradient_clipping_with_truncation (bool, default `True`): gradient clipping
gradient_clipping_with_truncation (bool, default ``True``): use gradient clipping
with truncation
Returns:
Instance of a :class:`~cntk.learner.Learner` that can be passed to the :class:`~cntk.trainer.Trainer`
'''
lr = learning_rate_schedule(lr)
gaussian_noise_injection_std_dev = training_parameter_schedule(gaussian_noise_injection_std_dev)
_verify_learning_rate_type(lr)
gaussian_noise_injection_std_dev = \
training_parameter_schedule(gaussian_noise_injection_std_dev, UnitType.minibatch)
additional_options = cntk_py.AdditionalLearningOptions()
additional_options.l1_regularization_weight = l1_regularization_weight
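To summarize the updated factory contract: every learner now verifies that lr and momentum are schedule objects, and the Gaussian-noise std-dev is coerced to a per-minibatch schedule internally. A sketch of an adam_sgd setup under the new API, using a one-weight stand-in model and illustrative schedule values:

from cntk import parameter, input_variable
from cntk.learner import (adam_sgd, learning_rate_schedule, UnitType,
                          momentum_as_time_constant_schedule)

w = parameter(shape=(1,), init=1)                # stand-in parameter
res = input_variable(shape=(1,)) * w

lr_per_sample = learning_rate_schedule([0.003]*2 + [0.0015]*12 + [0.0003],
                                       UnitType.sample, 18000)   # 18000 is illustrative
momentum_time_constant = momentum_as_time_constant_schedule(700)
learner = adam_sgd(res.parameters, lr=lr_per_sample, momentum=momentum_time_constant,
                   low_memory=True)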

View file

@ -40,9 +40,9 @@ def run_distributed_trainer(tmpdir, quantized):
dist_trainer = distributed.data_parallel_distributed_trainer(communicator, False)
momentum_time_constant = momentum_as_time_constant_schedule(1100)
lr_per_sample = learning_rate_schedule(0.007, UnitType.sample)
trainer = Trainer(z, ce, errs, \
momentum_sgd(z.parameters, 0.007, momentum_time_constant),
momentum_sgd(z.parameters, lr_per_sample, momentum_time_constant),
distributed_trainer=dist_trainer)
in1_value = [[1],[2]]
label_value = [[0], [1]]

View file

@ -10,20 +10,28 @@ from .. import parameter, input_variable
import pytest
SCHEDULE_PARAMS = [
LR_SCHEDULE_PARAMS = [
((0.2, UnitType.sample), [0.2]),
((0.2, UnitType.sample), [0.2, 0.2, 0.2, 0.2]),
(([0.2,0.4], UnitType.sample, 5), [0.2]*5+[0.4]*20),
(([(3,0.2),(2,0.4),(1,0.8)], UnitType.sample, 5), [0.2]*15+[0.4]*10+[0.8]*20),
]
MOMENTUM_SCHEDULE_PARAMS = [
((0.2,), [0.2]),
((0.2,), [0.2, 0.2, 0.2, 0.2]),
(([0.2,0.4], 5), [0.2]*5+[0.4]*20),
(([(3,0.2),(2,0.4),(1,0.8)], 5), [0.2]*15+[0.4]*10+[0.8]*20),
]
@pytest.mark.parametrize("params, expectation", SCHEDULE_PARAMS)
@pytest.mark.parametrize("params, expectation", LR_SCHEDULE_PARAMS)
def test_learning_rate_schedule(params, expectation):
l = learning_rate_schedule(*params)
assert [l[i] for i in range(len(expectation))] == expectation
def sweep_based_schedule_fails():
with pytest.raises(Exception):
learning_rate_schedule([1], epoch_size=0)
learning_rate_schedule([1], unit=UnitType.sample, epoch_size=0)
def test_momentum_schedule():
m = 2500
@ -38,7 +46,7 @@ def test_momentum_schedule():
expected = np.exp(-1.0 / np.asarray(mlist))
assert all(mi == ei for mi,ei in zip(msl,expected))
@pytest.mark.parametrize("params, expectation", SCHEDULE_PARAMS)
@pytest.mark.parametrize("params, expectation", MOMENTUM_SCHEDULE_PARAMS)
def test_momentum_schedule_per_sample(params, expectation):
l = momentum_schedule(*params)
assert [l[i] for i in range(len(expectation))] == expectation
@ -51,16 +59,11 @@ def test_learner_init():
res = i * w
learner = sgd(res.parameters, lr=learning_rate_schedule(0.1, 10000))
#per-sample learning rate does not depend on the minibatch size
learner = sgd(res.parameters, lr=learning_rate_schedule(0.1, UnitType.sample, 10000))
assert learner.learning_rate() == 0.1
assert learner.learning_rate(0) == 0.1
assert learner.learning_rate(10) == 0.1
learner.reset_learning_rate(learning_rate_schedule([1,2,3], unit=UnitType.minibatch));
assert learner.learning_rate(100) == 0.01
assert learner.learning_rate(0) == 0.0
learner.reset_learning_rate(learning_rate_schedule([1,2,3], UnitType.minibatch));
assert learner.learning_rate() == 1.0
learner_parameter = learner.parameters
from ..ops.variables import Parameter
@ -68,18 +71,21 @@ def test_learner_init():
assert isinstance(param, Parameter)
momentum_time_constant = momentum_as_time_constant_schedule(1100)
momentum_sgd(res.parameters, 0.1, momentum_time_constant)
lr_per_sample = learning_rate_schedule(0.1, UnitType.sample)
momentum_sgd(res.parameters, lr_per_sample, momentum_time_constant)
momentum_time_constant = momentum_schedule(momentum_time_constant) #should be ignored
nesterov(res.parameters, lr=[0.1, 0.2], momentum=momentum_time_constant)
lr_per_sample = learning_rate_schedule([0.1, 0.2], UnitType.sample)
nesterov(res.parameters, lr=lr_per_sample, momentum=momentum_time_constant)
adagrad(res.parameters, lr=[0.1]*3 +[0.2]*2 +[0.3], need_ave_multiplier=True)
lr_per_sample = learning_rate_schedule([0.1]*3 +[0.2]*2 +[0.3], UnitType.sample)
adagrad(res.parameters, lr=lr_per_sample, need_ave_multiplier=True)
momentum_time_constant = momentum_schedule(momentum_time_constant, unit=UnitType.minibatch) #should be ignored
adam_sgd(res.parameters, lr=[(3,0.1), (2, 0.2), (1, 0.3)], momentum=momentum_time_constant)
lr_per_sample = learning_rate_schedule([(3,0.1), (2, 0.2), (1, 0.3)], UnitType.sample)
adam_sgd(res.parameters, lr=lr_per_sample, momentum=momentum_time_constant)
gamma, inc, dec, max, min = [0.1]*5
rmsprop(res.parameters, learning_rate_schedule([0.1, 0.2], 100), gamma, inc, dec, max, min, True)
lr_per_sample = learning_rate_schedule([0.1, 0.2], UnitType.sample, 100)
rmsprop(res.parameters, lr_per_sample, gamma, inc, dec, max, min, True)
def test_learner_update():
i = input_variable(shape=(1,),
@ -89,7 +95,7 @@ def test_learner_update():
w = parameter(shape=(1,), init=w_init)
res = i * w
learner = sgd(res.parameters, lr=[0.1]*50 + [0.2]*50)
learner = sgd(res.parameters, lr=learning_rate_schedule([0.1]*50 + [0.2]*50, UnitType.sample))
assert learner.learning_rate() == 0.1
x = learner.update({w: np.asarray([[2.]], dtype=np.float32)}, 100)
assert learner.learning_rate() == 0.2

View file

@ -20,10 +20,10 @@ def test_trainer(tmpdir):
ce = cross_entropy_with_softmax(z, labels)
errs = classification_error(z, labels)
m_schedule = momentum_schedule(1100)
momentum_time_constant = momentum_as_time_constant_schedule(1100)
lr_per_sample = learning_rate_schedule(0.007, UnitType.sample)
trainer = Trainer(z, ce, errs, \
[momentum_sgd(z.parameters, 0.007, m_schedule)])
[momentum_sgd(z.parameters, lr_per_sample, momentum_time_constant)])
in1_value = [[1],[2]]
label_value = [[0], [1]]
arguments = {in1: in1_value, labels: label_value}
@ -49,10 +49,10 @@ def test_output_to_retain():
ce = cross_entropy_with_softmax(z, labels)
errs = classification_error(z, labels)
m_schedule = momentum_schedule(1100)
momentum_time_constant = momentum_as_time_constant_schedule(1100)
lr_per_sample = learning_rate_schedule(0.007, UnitType.sample)
trainer = Trainer(z, ce, errs, \
[momentum_sgd(z.parameters, 0.007, m_schedule)])
[momentum_sgd(z.parameters, lr_per_sample, momentum_time_constant)])
in1_value = [[1],[2]]
label_value = [[0], [1]]
arguments = {in1: in1_value, labels: label_value}