Integrate alrezni/v2_scratch into master
Commit babf078b49
@@ -124,7 +124,7 @@ def train_and_evaluate(reader_train, reader_test, max_epochs):
 minibatch_size = 64

 # Set learning parameters
-lr_per_minibatch = learning_rate_schedule([0.01]*10 + [0.003]*10 + [0.001], epoch_size, UnitType.minibatch)
+lr_per_minibatch = learning_rate_schedule([0.01]*10 + [0.003]*10 + [0.001], UnitType.minibatch, epoch_size)
 momentum_time_constant = momentum_as_time_constant_schedule(-minibatch_size/np.log(0.9))
 l2_reg_weight = 0.0001
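The recurring change across these example scripts is the argument order of learning_rate_schedule(): the unit now comes directly after the schedule values, and the optional epoch_size moves to the last position. A minimal sketch of the new call convention, assuming the post-change cntk.learner API shown in this diff (the epoch size of 50000 is a hypothetical value):

from cntk.learner import learning_rate_schedule, UnitType

# unit is now the second positional argument; epoch_size is an optional third one
lr = learning_rate_schedule([0.01]*10 + [0.003]*10 + [0.001], UnitType.minibatch, epoch_size=50000)
print(lr[0])  # 0.01 -- the first scheduled value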
@@ -10,7 +10,7 @@ import os
 from cntk import Trainer
 from cntk.io import MinibatchSource, CTFDeserializer, StreamDef, StreamDefs, INFINITELY_REPEAT, FULL_DATA_SWEEP
 from cntk.device import cpu, set_default_device
-from cntk.learner import sgd
+from cntk.learner import sgd, learning_rate_schedule, UnitType
 from cntk.ops import input_variable, cross_entropy_with_softmax, classification_error, relu, element_times, constant

 abs_path = os.path.dirname(os.path.abspath(__file__))

@@ -66,8 +66,9 @@ def simple_mnist(debug_output=False):
 label : reader_train.streams.labels
 }

+lr_per_minibatch=learning_rate_schedule(0.2, UnitType.minibatch)
 # Instantiate the trainer object to drive the model training
-trainer = Trainer(z, ce, pe, sgd(z.parameters, lr=1./320))
+trainer = Trainer(z, ce, pe, sgd(z.parameters, lr=lr_per_minibatch))

 # Get minibatches of images to train with and perform model training
 minibatch_size = 64
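As the SimpleMNIST hunk above shows, sgd() no longer takes a bare float learning rate; callers build a schedule object first and pass that in. A minimal sketch of the new pattern, assuming the post-change API, where z, ce and pe stand for the network, loss and error nodes defined earlier in the script:

from cntk import Trainer
from cntk.learner import sgd, learning_rate_schedule, UnitType

lr_per_minibatch = learning_rate_schedule(0.2, UnitType.minibatch)   # single value used for every minibatch
trainer = Trainer(z, ce, pe, sgd(z.parameters, lr=lr_per_minibatch))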
@@ -170,7 +170,7 @@ def train_and_evaluate(reader_train, reader_test, max_epochs):
 minibatch_size = 128

 # Set learning parameters
-lr_per_minibatch = learning_rate_schedule([1]*80 + [0.1]*40 + [0.01], epoch_size, UnitType.minibatch)
+lr_per_minibatch = learning_rate_schedule([1]*80 + [0.1]*40 + [0.01], UnitType.minibatch, epoch_size)
 momentum_time_constant = momentum_as_time_constant_schedule(-minibatch_size/np.log(0.9))
 l2_reg_weight = 0.0001
@@ -68,7 +68,8 @@ def cifar_resnet_distributed(data_path, run_test, num_epochs, communicator=None,
 num_mbs = num_mb_per_epoch * num_epochs

-lr_per_minibatch = learning_rate_schedule([1]*80 + [0.1]*40 + [0.01], mb_size * num_mb_per_epoch, UnitType.minibatch)
+lr_schedule = [1.0/mb_size]*80 + [0.1/mb_size]*40 + [0.01/mb_size]
+lr_per_minibatch = learning_rate_schedule(lr_schedule, UnitType.minibatch, mb_size * num_mb_per_epoch)
 momentum_time_constant = momentum_as_time_constant_schedule(-mb_size/np.log(0.9))

 # create data parallel distributed trainer if needed
@@ -12,7 +12,7 @@ from cntk.models import * # higher abstraction level, e.g. entire standard mode
 from cntk.utils import *
 from cntk.io import MinibatchSource, CTFDeserializer, StreamDef, StreamDefs
 from cntk import Trainer
-from cntk.learner import adam_sgd, learning_rate_schedule, momentum_as_time_constant_schedule
+from cntk.learner import adam_sgd, learning_rate_schedule, UnitType, momentum_as_time_constant_schedule
 from cntk.ops import cross_entropy_with_softmax, classification_error

 ########################

@@ -76,10 +76,10 @@ def train(reader, model, max_epochs):
 num_mbs_to_show_result = 100
 momentum_time_constant = momentum_as_time_constant_schedule(minibatch_size / -math.log(0.9)) # TODO: Change to round number. This is 664.39. 700?

-lr_schedule = [0.003]*2+[0.0015]*12+[0.0003] # LR schedule over epochs (we don't run that mayn epochs, but if we did, these are good values)
+lr_schedule = [0.003]*2+[0.0015]*12+[0.0003] # LR schedule over epochs (we don't run that many epochs, but if we did, these are good values)

 # trainer object
-lr_per_sample = learning_rate_schedule(lr_schedule, epoch_size)
+lr_per_sample = learning_rate_schedule(lr_schedule, UnitType.sample, epoch_size)
 learner = adam_sgd(z.parameters,
 lr=lr_per_sample, momentum=momentum_time_constant,
 low_memory=True,
@@ -9,7 +9,7 @@ import os
 from cntk import Trainer, Axis #, text_format_minibatch_source, StreamConfiguration
 from cntk.io import MinibatchSource, CTFDeserializer, StreamDef, StreamDefs, INFINITELY_REPEAT, FULL_DATA_SWEEP
 from cntk.device import cpu, set_default_device
-from cntk.learner import sgd
+from cntk.learner import sgd, learning_rate_schedule, UnitType
 from cntk.ops import input_variable, cross_entropy_with_softmax, classification_error, sequence

 abs_path = os.path.dirname(os.path.abspath(__file__))

@@ -62,9 +62,10 @@ def train_sequence_classifier(debug_output=False):
 label : reader.streams.labels
 }

+lr_per_sample = learning_rate_schedule(0.0005, UnitType.sample)
 # Instantiate the trainer object to drive the model training
 trainer = Trainer(classifier_output, ce, pe,
-sgd(classifier_output.parameters, lr=0.0005))
+sgd(classifier_output.parameters, lr=lr_per_sample))

 # Get minibatches of sequences to train with and perform model training
 minibatch_size = 200
@@ -10,7 +10,7 @@ import os
 from cntk import Trainer, Axis, save_model, load_model #, text_format_minibatch_source, StreamConfiguration
 from cntk.io import MinibatchSource, CTFDeserializer, StreamDef, StreamDefs, INFINITELY_REPEAT, FULL_DATA_SWEEP
 from cntk.device import cpu, set_default_device
-from cntk.learner import momentum_sgd, momentum_as_time_constant_schedule
+from cntk.learner import learning_rate_schedule, UnitType, momentum_sgd, momentum_as_time_constant_schedule
 from cntk.ops import input_variable, cross_entropy_with_softmax, classification_error, sequence, past_value, future_value, element_select, alias, hardmax
 from cntk.ops.functions import CloneMethod

@@ -151,13 +151,12 @@ def sequence_to_sequence_translator(debug_output=False, run_test=False):
 ng = z.clone(CloneMethod.share, {decoder_history_hook.output : net_output.output})

 # Instantiate the trainer object to drive the model training
-lr_per_sample = 0.007
-minibatch_size = 72
+lr_per_minibatch = learning_rate_schedule(0.5, UnitType.minibatch)
 momentum_time_constant = momentum_as_time_constant_schedule(1100)
 clipping_threshold_per_sample = 2.3
 gradient_clipping_with_truncation = True
 learner = momentum_sgd(z.parameters,
-lr_per_sample, momentum_time_constant,
+lr_per_minibatch, momentum_time_constant,
 gradient_clipping_threshold_per_sample=clipping_threshold_per_sample,
 gradient_clipping_with_truncation=gradient_clipping_with_truncation)
 trainer = Trainer(z, ce, errs, learner)

@@ -185,6 +184,7 @@ def sequence_to_sequence_translator(debug_output=False, run_test=False):
 # Get minibatches of sequences to train with and perform model training
 i = 0
 mbs = 0
+minibatch_size = 72
 epoch_size = 908241
 max_epochs = 10
 training_progress_output_freq = 500

@@ -240,7 +240,7 @@ def sequence_to_sequence_translator(debug_output=False, run_test=False):
 ce = cross_entropy_with_softmax(z, label_sequence)
 errs = classification_error(z, label_sequence)
 trainer = Trainer(z, ce, errs, [momentum_sgd(
-z.parameters, lr_per_sample, momentum_time_constant, clipping_threshold_per_sample, gradient_clipping_with_truncation)])
+z.parameters, lr_per_minibatch, momentum_time_constant, clipping_threshold_per_sample, gradient_clipping_with_truncation)])

 error2 = translator_test_error(z, trainer, input_vocab_dim, label_vocab_dim)
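The Sequence2Sequence script now drives momentum_sgd() with a per-minibatch learning-rate schedule and a momentum time constant instead of raw per-sample floats. A minimal sketch of the new call, assuming the post-change API, with z, ce and errs standing for the model, loss and error nodes of the example:

from cntk import Trainer
from cntk.learner import momentum_sgd, learning_rate_schedule, UnitType, momentum_as_time_constant_schedule

lr_per_minibatch = learning_rate_schedule(0.5, UnitType.minibatch)
momentum_time_constant = momentum_as_time_constant_schedule(1100)
learner = momentum_sgd(z.parameters, lr_per_minibatch, momentum_time_constant,
                       gradient_clipping_threshold_per_sample=2.3,
                       gradient_clipping_with_truncation=True)
trainer = Trainer(z, ce, errs, learner)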
@@ -2963,14 +2963,14 @@ namespace CNTK
 ///
 /// Create a schedule with a constant parameter value.
 ///
-CNTK_API TrainingParameterSchedule(T value, UnitType unit = UnitType::Sample);
+CNTK_API TrainingParameterSchedule(T value, UnitType unit);

 ///
 /// Create a schedule where the parameter changes its value every 'epochSize' samples:
 /// schedule[0] is used for the first 'epochSize' samples, schedule[1] -- for the second,
 /// and so on. The last value is then used repeatedly until the end of training.
 ///
-CNTK_API TrainingParameterSchedule(const std::vector<T>& schedule, size_t epochSize = 1, UnitType unit = UnitType::Sample);
+CNTK_API TrainingParameterSchedule(const std::vector<T>& schedule, UnitType unit, size_t epochSize = 1);

 ///
 /// Create a schedule using the list of key-value pairs, where the key specifies

@@ -2981,7 +2981,7 @@ namespace CNTK
 /// the first 100 samples, then '0.1' is used for the second 200 samples,
 /// after which the values is switched to '0.005'.
 ///
-CNTK_API TrainingParameterSchedule(const std::vector<std::pair<size_t, T>>& schedule, size_t epochSize = 1, UnitType unit = UnitType::Sample);
+CNTK_API TrainingParameterSchedule(const std::vector<std::pair<size_t, T>>& schedule, UnitType unit, size_t epochSize = 1);

 ///
 /// Returns a value corresponding to the absolute sample (or sweep)

@@ -3033,11 +3033,11 @@ namespace CNTK
 { }

 TrainingParameterPerUnitSchedule(const std::vector<double>& schedule, size_t epochSize = 1)
-: TrainingParameterSchedule<T>::TrainingParameterSchedule(schedule, epochSize, U)
+: TrainingParameterSchedule<T>::TrainingParameterSchedule(schedule, U, epochSize)
 { }

 TrainingParameterPerUnitSchedule(const std::vector<std::pair<size_t, double>>& schedule, size_t epochSize = 1)
-: TrainingParameterSchedule<T>::TrainingParameterSchedule(schedule, epochSize, U)
+: TrainingParameterSchedule<T>::TrainingParameterSchedule(schedule, U, epochSize)
 { }

 #ifdef SWIG // for Python interop (adds indexer)

@@ -3077,19 +3077,19 @@ namespace CNTK
 {
 public:
 MomentumAsTimeConstantSchedule(double value)
-: TrainingParameterSchedule<double>::TrainingParameterSchedule(value)
+: TrainingParameterSchedule<double>::TrainingParameterSchedule(value, UnitType::Sample)
 {
 ConvertToPerSampleValues();
 }

 MomentumAsTimeConstantSchedule(const std::vector<double>& schedule, size_t epochSize = 1)
-: TrainingParameterSchedule<double>::TrainingParameterSchedule(schedule, epochSize)
+: TrainingParameterSchedule<double>::TrainingParameterSchedule(schedule, UnitType::Sample, epochSize)
 {
 ConvertToPerSampleValues();
 }

 MomentumAsTimeConstantSchedule(const std::vector<std::pair<size_t, double>>& schedule, size_t epochSize = 1)
-: TrainingParameterSchedule<double>::TrainingParameterSchedule(schedule, epochSize)
+: TrainingParameterSchedule<double>::TrainingParameterSchedule(schedule, UnitType::Sample, epochSize)
 {
 ConvertToPerSampleValues();
 }

@@ -3112,7 +3112,11 @@ namespace CNTK
 {
 double l1RegularizationWeight = 0.0;
 double l2RegularizationWeight = 0.0;
-TrainingParameterSchedule<double> gaussianNoiseInjectionStdDev = 0.0;
+#ifdef SWIG //for python interop (swig does not fully support "using")
+TrainingParameterPerUnitSchedule<double, TrainingParameterSchedule<double>::UnitType::Minibatch> gaussianNoiseInjectionStdDev = 0.0;
+#else
+TrainingParameterPerMinibatchSchedule<double> gaussianNoiseInjectionStdDev = 0.0;
+#endif
 double gradientClippingThresholdPerSample = std::numeric_limits<double>::infinity();
 bool gradientClippingWithTruncation = true;
 };

@@ -3171,18 +3175,11 @@ namespace CNTK
 virtual void ResetSmoothedGradients() = 0;

 ///
-/// Returns current (per-sample) learning rate.
+/// Returns current learning rate.
 ///
-virtual double LearningRate(size_t minibatchSize = 1) const
+virtual double LearningRate() const
 {
-auto learningRate = GetCurrentTrainingParameterValue<double>(m_learningRateSchedule);
-if (m_learningRateSchedule.Unit() == LearningRateSchedule::UnitType::Minibatch)
-{
-// learning rate needs to be converted to the per-sample value.
-return (minibatchSize == 0) ? 0.0 : learningRate / minibatchSize;
-}
-
-return learningRate;
+return GetCurrentTrainingParameterValue<double>(m_learningRateSchedule);
 }

 protected:
@@ -41,6 +41,19 @@ namespace CNTK

 std::string LearnerType() const;

+// Returns current (per-sample) learning rate.
+double LearningRate(size_t minibatchSize) const
+{
+auto learningRate = Learner::LearningRate();
+if (m_learningRateSchedule.Unit() == LearningRateSchedule::UnitType::Minibatch)
+{
+// learning rate needs to be converted to the per-sample value.
+return (minibatchSize == 0) ? 0.0 : learningRate / minibatchSize;
+}
+
+return learningRate;
+}
+
 AdditionalLearningOptions m_additionalOptions;

 std::unordered_map<Parameter, NDArrayViewPtr> m_smoothedGradientValues;
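The helper added to Learner.h above keeps a per-sample view of the learning rate for the update code: when the schedule is expressed per minibatch, the current value is divided by the actual minibatch size. A rough Python rendering of that logic, for illustration only (the function name and the string unit tag are hypothetical):

def per_sample_rate(current_rate, unit, minibatch_size):
    # mirrors the C++ helper: per-minibatch rates are rescaled by the actual minibatch size
    if unit == 'minibatch':
        return 0.0 if minibatch_size == 0 else current_rate / minibatch_size
    return current_rate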
@@ -244,7 +244,7 @@ namespace CNTK
 }

 template <typename T>
-TrainingParameterSchedule<T>::TrainingParameterSchedule(const vector<T>& schedule, size_t epochSize, UnitType unit)
+TrainingParameterSchedule<T>::TrainingParameterSchedule(const vector<T>& schedule, UnitType unit, size_t epochSize)
 : m_unit(unit), m_epochSize(epochSize)
 {
 std::vector<std::pair<size_t, T>> s(schedule.size());

@@ -257,7 +257,7 @@ namespace CNTK
 }

 template <typename T>
-TrainingParameterSchedule<T>::TrainingParameterSchedule(const vector<std::pair<size_t, T>>& schedule, size_t epochSize, UnitType unit)
+TrainingParameterSchedule<T>::TrainingParameterSchedule(const vector<std::pair<size_t, T>>& schedule, UnitType unit, size_t epochSize)
 : m_unit(unit), m_epochSize(epochSize)
 {
 ConstructSchedule(schedule);
@@ -50,7 +50,7 @@ void TrainSimpleDistributedFeedForwardClassifer(const DeviceDescriptor& device,
 auto trainingLoss = CNTK::CrossEntropyWithSoftmax(classifierOutput, labels, L"lossFunction");;
 auto prediction = CNTK::ClassificationError(classifierOutput, labels, L"classificationError");

-double learningRatePerSample = 0.02;
+auto learningRatePerSample = LearningRatePerSampleSchedule(0.02);
 minibatchSource = TextFormatMinibatchSource(L"SimpleDataTrain_cntk_text.txt", { { L"features", inputDim }, { L"labels", numOutputClasses } });
 Trainer trainer(classifierOutput, trainingLoss, prediction, { SGDLearner(classifierOutput->Parameters(), learningRatePerSample) }, distributedTrainer);
 size_t outputFrequencyInMinibatches = 20;
@@ -157,7 +157,7 @@ void TrainResNetCifarClassifer(const DeviceDescriptor& device, bool testSaveAndR
 classifierOutput = classifierOutputVar;
 }

-double learningRatePerSample = 0.0078125;
+LearningRatePerSampleSchedule learningRatePerSample = 0.0078125;
 Trainer trainer(classifierOutput, trainingLoss, prediction, { SGDLearner(classifierOutput->Parameters(), learningRatePerSample) });

 const size_t minibatchSize = 32;
@@ -51,7 +51,7 @@ void TestSGDLearner(size_t numParameters, size_t numMinibatches, const DeviceDes
 {
 NDShape shape = CreateShape(rng() % maxNumAxes + 1, maxDimSize);
 auto parameters = CreateParameters<ElementType>(shape, numParameters, device);
-auto learner = SGDLearner(parameters, 0.4);
+auto learner = SGDLearner(parameters, LearningRatePerSampleSchedule(0.4));
 TestUpdate<ElementType>(learner, shape, numMinibatches, device);
 }

@@ -61,10 +61,10 @@ void TestMomentumSGDLearner(size_t numParameters, size_t numMinibatches, const D
 NDShape shape = CreateShape(rng() % maxNumAxes + 1, maxDimSize);
 auto parameters = CreateParameters<ElementType>(shape, numParameters, device);
 LearningRatePerMinibatchSchedule learnigRateSchedule = { { 3.0, 2.0, 1.0 }, numMinibatches };
-MomentumSchedule momentumValues = { { { 1, 1.0 }, { 3, 0.1 }, { 10, 0.01 } }, 2 };
+MomentumPerSampleSchedule momentumValues = { { { 1, 1.0 }, { 3, 0.1 }, { 10, 0.01 } }, 2 };
 auto learner = MomentumSGDLearner(parameters, learnigRateSchedule, momentumValues);
 TestUpdate<ElementType>(learner, shape, numMinibatches, device);
-FloatingPointCompare(learner->LearningRate(100), 0.02, "Learner::LearningRate does not match expectation");
+FloatingPointCompare(learner->LearningRate(), 2.0, "Learner::LearningRate does not match expectation");
 }

 template <typename ElementType>

@@ -91,7 +91,7 @@ void TestFSAdaGradLearner(size_t numParameters, size_t numMinibatches, const Dev
 {
 NDShape shape = CreateShape(rng() % maxNumAxes + 1, maxDimSize);
 auto parameters = CreateParameters<ElementType>(shape, numParameters, device);
-auto learner = AdamLearner(parameters, { { 0.5 } }, MomentumAsTimeConstantSchedule({ 10, 100, 1000 }));
+auto learner = AdamLearner(parameters, LearningRatePerSampleSchedule({ 0.5 }), MomentumAsTimeConstantSchedule({ 10.0, 100.0, 1000.0 }));
 TestUpdate<ElementType>(learner, shape, numMinibatches, device);
 }

@@ -100,7 +100,7 @@ void TestRMSPropLearner(size_t numParameters, size_t numMinibatches, const Devic
 {
 NDShape shape = CreateShape(rng() % maxNumAxes + 1, maxDimSize);
 auto parameters = CreateParameters<ElementType>(shape, numParameters, device);
-auto learner = RMSPropLearner(parameters, { { { 3, 0.7 }, { 1, 0.2 } } }, 0.01, 0.02, 0.03, 0.1, 0.001);
+auto learner = RMSPropLearner(parameters, LearningRatePerMinibatchSchedule({ { 3, 0.7 }, { 1, 0.2 } }), 0.01, 0.02, 0.03, 0.1, 0.001);
 TestUpdate<ElementType>(learner, shape, numMinibatches, device);
 }

@@ -110,19 +110,19 @@ void TestTrainingParametersSchedule()
 LearningRatePerMinibatchSchedule({ 3.0, 2.0, 1.0 }, LearningRateSchedule::EntireSweep);
 }, "Was able to create not-yet-implemented sweep-based schedule.");

-LearningRateSchedule schedule1 = 0.5;
+LearningRatePerSampleSchedule schedule1 = 0.5;
 assert(schedule1.Unit() == LearningRateSchedule::UnitType::Sample);
 assert(schedule1[0] == 0.5);
 assert(schedule1[1] == 0.5);
 assert(schedule1[100] == 0.5);

-LearningRateSchedule schedule2 = { 0.5 };
+LearningRatePerSampleSchedule schedule2 = { 0.5 };
 assert(schedule2.Unit() == LearningRateSchedule::UnitType::Sample);
 assert(schedule2[0] == 0.5);
 assert(schedule2[10] == 0.5);
 assert(schedule2[100] == 0.5);

-LearningRateSchedule schedule3 = { { 0.5, 0.3, 0.3 } };
+LearningRatePerSampleSchedule schedule3({ 0.5, 0.3, 0.3 });
 assert(schedule3.Unit() == LearningRateSchedule::UnitType::Sample);
 assert(schedule3[0] == 0.5);
 assert(schedule3[1] == 0.3);

@@ -143,8 +143,8 @@ void TestTrainingParametersSchedule()
 assert(schedule5[20] == 0.2);
 assert(schedule5[100] == 0.2);

-MomentumSchedule schedule6 = { { make_pair(1, 0.5) } }; // without make_pair this is interpreted as a vector of doubles
-assert(schedule6.Unit() == MomentumSchedule::UnitType::Sample);
+MomentumPerMinibatchSchedule schedule6 = { { make_pair(1, 0.5) } }; // without make_pair this is interpreted as a vector of doubles
+assert(schedule6.Unit() == MomentumSchedule::UnitType::Minibatch);
 assert(schedule6[0] == 0.5);
 assert(schedule6[10] == 0.5);
 assert(schedule6[100] == 0.5);

@@ -165,7 +165,7 @@ void TestTrainingParametersSchedule()
 assert(schedule8[20] == 0.2);
 assert(schedule8[100] == 0.2);

-LearningRateSchedule schedule9 = { { { 3, 0.5 }, { 2, 0.3 }, { 1, 0.2 } } };
+LearningRatePerSampleSchedule schedule9 = { { { 3, 0.5 }, { 2, 0.3 }, { 1, 0.2 } } };
 assert(schedule9.Unit() == LearningRateSchedule::UnitType::Sample);
 assert(schedule9[0] == 0.5);
 assert(schedule9[2] == 0.5);
@@ -175,13 +175,12 @@ void TrainSequenceToSequenceTranslator(const DeviceDescriptor& device, bool useS
 auto rawInputStreamInfo = minibatchSource->StreamInfo(featureStreamName);
 auto rawLabelsStreamInfo = minibatchSource->StreamInfo(labelStreamName);

-double learningRatePerSample = 0.007;
-size_t momentumTimeConstant = 1100;
-double momentumPerSample = std::exp(-1.0 / momentumTimeConstant);
+LearningRatePerSampleSchedule learningRatePerSample = 0.007;
+MomentumAsTimeConstantSchedule momentumTimeConstant = 1100;
 AdditionalLearningOptions additionalOptions;
 additionalOptions.gradientClippingThresholdPerSample = 2.3;
 additionalOptions.gradientClippingWithTruncation = true;
-Trainer trainer(z, ce, errs, { MomentumSGDLearner(z->Parameters(), learningRatePerSample, momentumPerSample, additionalOptions) });
+Trainer trainer(z, ce, errs, { MomentumSGDLearner(z->Parameters(), learningRatePerSample, momentumTimeConstant, additionalOptions) });

 size_t outputFrequencyInMinibatches = 1;
 size_t minibatchSize1 = 72;
@@ -46,10 +46,9 @@ void TrainLSTMSequenceClassifer(const DeviceDescriptor& device, bool useSparseLa
 auto featureStreamInfo = minibatchSource->StreamInfo(featuresName);
 auto labelStreamInfo = minibatchSource->StreamInfo(labelsName);

-double learningRatePerSample = 0.0005;
-size_t momentumTimeConstant = 256;
-double momentumPerSample = std::exp(-1.0 / momentumTimeConstant);
-Trainer trainer(classifierOutput, trainingLoss, prediction, { MomentumSGDLearner(classifierOutput->Parameters(), learningRatePerSample, momentumPerSample) });
+LearningRatePerSampleSchedule learningRatePerSample = 0.0005;
+MomentumAsTimeConstantSchedule momentumTimeConstant = 256;
+Trainer trainer(classifierOutput, trainingLoss, prediction, { MomentumSGDLearner(classifierOutput->Parameters(), learningRatePerSample, momentumTimeConstant) });

 size_t outputFrequencyInMinibatches = 1;
 for (size_t i = 0; true; i++)

@@ -91,68 +90,68 @@ void TestLearningRateControl(const DeviceDescriptor& device)
 LearningRatePerSampleSchedule learningRateSchedule({ { 2, 0.0005 }, { 2, 0.00025 } }, actualMBSize);
 auto learner = SGDLearner(classifierOutput->Parameters(), learningRateSchedule);
 Trainer trainer(classifierOutput, trainingLoss, prediction, { learner });
-FloatingPointCompare(learner->LearningRate(minibatchSize), 0.0005, "Learner::LearningRate does not match expectation");
+FloatingPointCompare(learner->LearningRate(), 0.0005, "Learner::LearningRate does not match expectation");

 trainer.TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
-FloatingPointCompare(learner->LearningRate(minibatchSize), 0.0005, "Learner::LearningRate does not match expectation");
+FloatingPointCompare(learner->LearningRate(), 0.0005, "Learner::LearningRate does not match expectation");

 const wchar_t* modelFile = L"seq2seq.model";
 trainer.SaveCheckpoint(modelFile);

 trainer.TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
 auto MB2Loss = trainer.PreviousMinibatchLossAverage();
-FloatingPointCompare(learner->LearningRate(minibatchSize), 0.00025, "Learner::LearningRate does not match expectation");
+FloatingPointCompare(learner->LearningRate(), 0.00025, "Learner::LearningRate does not match expectation");

 trainer.TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
 auto MB3Loss = trainer.PreviousMinibatchLossAverage();
-FloatingPointCompare(learner->LearningRate(minibatchSize), 0.00025, "Learner::LearningRate does not match expectation");
+FloatingPointCompare(learner->LearningRate(), 0.00025, "Learner::LearningRate does not match expectation");

 trainer.RestoreFromCheckpoint(modelFile);
-FloatingPointCompare(learner->LearningRate(minibatchSize), 0.0005, "Learner::LearningRate does not match expectation");
+FloatingPointCompare(learner->LearningRate(), 0.0005, "Learner::LearningRate does not match expectation");

 trainer.TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
 auto postRestoreMB2Loss = trainer.PreviousMinibatchLossAverage();
 FloatingPointCompare(postRestoreMB2Loss, MB2Loss, "Post checkpoint restoration training loss does not match expectation");

-FloatingPointCompare(learner->LearningRate(minibatchSize), 0.00025, "Learner::LearningRate does not match expectation");
+FloatingPointCompare(learner->LearningRate(), 0.00025, "Learner::LearningRate does not match expectation");

 trainer.TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
 auto postRestoreMB3Loss = trainer.PreviousMinibatchLossAverage();
 FloatingPointCompare(postRestoreMB3Loss, MB3Loss, "Post checkpoint restoration training loss does not match expectation");

 trainer.RestoreFromCheckpoint(modelFile);
-FloatingPointCompare(learner->LearningRate(minibatchSize), 0.0005, "Learner::LearningRate does not match expectation");
+FloatingPointCompare(learner->LearningRate(), 0.0005, "Learner::LearningRate does not match expectation");

-learner->ResetLearningRate(0.0004);
-FloatingPointCompare(learner->LearningRate(minibatchSize), 0.0004, "Learner::LearningRate does not match expectation");
+learner->ResetLearningRate(LearningRatePerSampleSchedule(0.0004));
+FloatingPointCompare(learner->LearningRate(), 0.0004, "Learner::LearningRate does not match expectation");

 trainer.SaveCheckpoint(modelFile);
 trainer.TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
 postRestoreMB2Loss = trainer.PreviousMinibatchLossAverage();
 FloatingPointCompare(postRestoreMB2Loss, MB2Loss, "Post checkpoint restoration training loss does not match expectation");

-FloatingPointCompare(learner->LearningRate(minibatchSize), 0.0004, "Learner::LearningRate does not match expectation");
+FloatingPointCompare(learner->LearningRate(), 0.0004, "Learner::LearningRate does not match expectation");

 trainer.TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
 postRestoreMB3Loss = trainer.PreviousMinibatchLossAverage();
 FloatingPointCompare(postRestoreMB3Loss, MB3Loss, "Post checkpoint restoration training loss does not match expectation");

-FloatingPointCompare(learner->LearningRate(minibatchSize), 0.0004, "Learner::LearningRate does not match expectation");
+FloatingPointCompare(learner->LearningRate(), 0.0004, "Learner::LearningRate does not match expectation");

 trainer.RestoreFromCheckpoint(modelFile);
-FloatingPointCompare(learner->LearningRate(minibatchSize), 0.0004, "Learner::LearningRate does not match expectation");
+FloatingPointCompare(learner->LearningRate(), 0.0004, "Learner::LearningRate does not match expectation");

 trainer.TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
 postRestoreMB2Loss = trainer.PreviousMinibatchLossAverage();
 FloatingPointCompare(postRestoreMB2Loss, MB2Loss, "Post checkpoint restoration training loss does not match expectation");

-FloatingPointCompare(learner->LearningRate(minibatchSize), 0.0004, "Learner::LearningRate does not match expectation");
+FloatingPointCompare(learner->LearningRate(), 0.0004, "Learner::LearningRate does not match expectation");

 trainer.TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
 postRestoreMB3Loss = trainer.PreviousMinibatchLossAverage();
 FloatingPointCompare(postRestoreMB3Loss, MB3Loss, "Post checkpoint restoration training loss does not match expectation");

-FloatingPointCompare(learner->LearningRate(minibatchSize), 0.0004, "Learner::LearningRate does not match expectation");
+FloatingPointCompare(learner->LearningRate(), 0.0004, "Learner::LearningRate does not match expectation");
 }

 void TrainLSTMSequenceClassifer()
@@ -203,7 +203,7 @@ void TestLearnerSerialization(int numParameters, const DeviceDescriptor& device)
 gradientValues[parameter] = NDArrayView::RandomUniform<ElementType>(shape, -0.5, 0.5, numParameters + i, device);
 }

-auto learner1 = SGDLearner(parameters, 0.05);
+auto learner1 = SGDLearner(parameters, LearningRatePerSampleSchedule(0.05));

 learner1->Update(gradientValues, 1);

@@ -215,7 +215,7 @@ void TestLearnerSerialization(int numParameters, const DeviceDescriptor& device)
 stream.flush();
 }

-auto learner2 = SGDLearner(parameters, 0.05);
+auto learner2 = SGDLearner(parameters, LearningRatePerSampleSchedule( 0.05));

 {
 Dictionary checkpoint;

@@ -405,7 +405,9 @@ void TestFunctionSerialization(const DeviceDescriptor& device)
 TestFunctionSaveAndLoad(BuildLSTMClassifierNet(inputVar, 5, device), device);
 }

-Trainer BuildTrainer(const FunctionPtr& function, const Variable& labels, LearningRateSchedule lr = 0.005, MomentumSchedule m = 0.0)
+Trainer BuildTrainer(const FunctionPtr& function, const Variable& labels,
+LearningRateSchedule lr = LearningRatePerSampleSchedule(0.005),
+MomentumSchedule m = MomentumAsTimeConstantSchedule(0.0))
 {
 auto trainingLoss = CNTK::CrossEntropyWithSoftmax(function, labels, L"lossFunction");
 auto prediction = CNTK::ClassificationError(function, labels, L"classificationError");

@@ -498,7 +500,7 @@ void TestTrainingWithCheckpointing(const FunctionPtr& function1, const FunctionP
 auto minibatchData = minibatchSource->GetNextMinibatch(minibatchSize, device);
 auto actualMBSize = minibatchData[labelStreamInfo].m_numSamples;

-LearningRateSchedule learningRateSchedule({ { 2, 0.005 }, { 2, 0.0025 }, { 2, 0.0005 }, { 2, 0.00025 } }, actualMBSize);
+LearningRatePerSampleSchedule learningRateSchedule({ { 2, 0.005 }, { 2, 0.0025 }, { 2, 0.0005 }, { 2, 0.00025 } }, actualMBSize);
 MomentumAsTimeConstantSchedule momentumValues({ { 2, 100 }, { 2, 200 }, { 2, 400 }, { 2, 800 } }, actualMBSize);

@@ -633,7 +635,7 @@ void TestLegacyModelSaving(const DeviceDescriptor& device)
 auto minibatchData = minibatchSource->GetNextMinibatch(minibatchSize, device);
 auto actualMBSize = minibatchData[labelStreamInfo].m_numSamples;

-LearningRateSchedule learningRateSchedule({ { 2, 0.0005 }, { 2, 0.00025 } }, actualMBSize);
+LearningRatePerSampleSchedule learningRateSchedule({ { 2, 0.0005 }, { 2, 0.00025 } }, actualMBSize);
 auto learner = SGDLearner(classifierOutput->Parameters(), learningRateSchedule);
 Trainer trainer(classifierOutput, trainingLoss, prediction, { learner });
@@ -59,7 +59,7 @@ void TrainSimpleFeedForwardClassifer(const DeviceDescriptor& device)
 prediction = predictionVar;
 }

-double learningRatePerSample = 0.02;
+LearningRatePerSampleSchedule learningRatePerSample = 0.02;
 minibatchSource = TextFormatMinibatchSource(L"SimpleDataTrain_cntk_text.txt", { { L"features", inputDim }, { L"labels", numOutputClasses } });
 Trainer trainer(classifierOutput, trainingLoss, prediction, { SGDLearner(classifierOutput->Parameters(), learningRatePerSample) });
 size_t outputFrequencyInMinibatches = 20;

@@ -121,7 +121,7 @@ void TrainMNISTClassifier(const DeviceDescriptor& device)
 auto featureStreamInfo = minibatchSource->StreamInfo(featureStreamName);
 auto labelStreamInfo = minibatchSource->StreamInfo(labelsStreamName);

-double learningRatePerSample = 0.003125;
+LearningRatePerSampleSchedule learningRatePerSample = 0.003125;
 Trainer trainer(classifierOutput, trainingLoss, prediction, { SGDLearner(classifierOutput->Parameters(), learningRatePerSample) });

 size_t outputFrequencyInMinibatches = 20;
@@ -109,10 +109,9 @@ void TrainTruncatedLSTMAcousticModelClassifer(const DeviceDescriptor& device, bo
 featureStreamInfo = minibatchSource->StreamInfo(features);
 auto labelStreamInfo = minibatchSource->StreamInfo(labels);

-double learningRatePerSample = 0.000781;
-size_t momentumTimeConstant = 6074;
-double momentumPerSample = std::exp(-1.0 / momentumTimeConstant);
-auto learner = MomentumSGDLearner(classifierOutput->Parameters(), learningRatePerSample, momentumPerSample);
+LearningRatePerSampleSchedule learningRatePerSample = 0.000781;
+MomentumAsTimeConstantSchedule momentumTimeConstant = 6074;
+auto learner = MomentumSGDLearner(classifierOutput->Parameters(), learningRatePerSample, momentumTimeConstant);
 Trainer trainer(classifierOutput, trainingLoss, prediction, {learner});

 size_t outputFrequencyInMinibatches = 1;
@@ -9,7 +9,7 @@ import sys
 import os
 from cntk.device import cpu, set_default_device
 from cntk import Trainer
-from cntk.learner import sgd
+from cntk.learner import sgd, learning_rate_schedule, UnitType
 from cntk.ops import input_variable, cross_entropy_with_softmax, classification_error, sigmoid
 from cntk.utils import ProgressPrinter

@@ -52,8 +52,9 @@ def ffnet():
 ce = cross_entropy_with_softmax(netout, label)
 pe = classification_error(netout, label)

+lr_per_minibatch=learning_rate_schedule(0.5, UnitType.minibatch)
 # Instantiate the trainer object to drive the model training
-trainer = Trainer(netout, ce, pe, sgd(netout.parameters, lr=0.02))
+trainer = Trainer(netout, ce, pe, sgd(netout.parameters, lr=lr_per_minibatch))

 # Get minibatches of training data and perform model training
 minibatch_size = 25
@@ -50,6 +50,33 @@ the following learning algorithms:
 +------------------------+
 '''

+# an internal method to verify that the learning rate schedule
+# has a proper (per-sample or per-MB schedule) type and raise
+# an exception otherwise
+def _verify_learning_rate_type(learning_rate):
+if not isinstance(learning_rate,
+(cntk_py.training_parameter_per_sample_schedule,
+cntk_py.training_parameter_per_minibatch_schedule)):
+
+raise ValueError('learning_rate type (%s) not supported. '
+'learning_rate must be a training schedule '
+'(output of learning_rate_schedule() function)'
+% type(learning_rate))
+
+# an internal method to verify that the mometum schedule
+# has a proper (per-MB or time-constant schedule) type and raise
+# an exception otherwise
+def _verify_momentum_type(momentum):
+if not isinstance(momentum,
+(cntk_py.training_parameter_per_minibatch_schedule,
+cntk_py.momentum_as_time_constant_schedule)):
+
+raise ValueError('momentum type (%s) not supported. '
+'momentum must be a training schedule '
+'(output of momentum_schedule() or '
+'momentum_as_time_constant_schedule() function)'
+% type(momentum))
+
 class Learner(cntk_py.Learner):
 '''
 Abstraction for learning a subset of parameters of a learnable function using first order gradient values
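With the two _verify_* helpers above in place, the Python learner factories reject raw floats, so a call that used to work, such as sgd(z.parameters, lr=0.02), now needs an explicit schedule. A minimal sketch of the failure and the fix, assuming the post-change API (z stands for a model function as in the surrounding examples):

from cntk.learner import sgd, learning_rate_schedule, UnitType

try:
    learner = sgd(z.parameters, lr=0.02)        # bare float: now rejected
except ValueError as err:
    print(err)                                  # learning_rate type (<class 'float'>) not supported ...

learner = sgd(z.parameters, lr=learning_rate_schedule(0.02, UnitType.minibatch))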
@@ -63,7 +90,7 @@ class Learner(cntk_py.Learner):
 Update the parameters associated with this learner.

 Args:
-gradient_values (`dict`): maps :class:`~cntk.variables.Parameter` to
+gradient_values (dict): maps :class:`~cntk.variables.Parameter` to
 a NumPy array containing the first order gradient values for the
 Parameter w.r.t. the training objective.
 training_sample_count (int): training sample count

@@ -85,49 +112,43 @@ class Learner(cntk_py.Learner):
 '''
 return super(Learner, self).parameters()


 def reset_learning_rate(self, learning_rate):
 '''
 Resets the learning rate.

 Args:
-learning_rate (float, list or a training schedule): learning rate
-to reset to
+learning_rate (output of :func:`learning_rate_schedule`)
+learning rate to reset to
 '''
-learning_rate = learning_rate_schedule(learning_rate)
+_verify_learning_rate_type(learning_rate)
 return super(Learner, self).reset_learning_rate(learning_rate)

-def learning_rate(self, minibatch_size=1):
+def learning_rate(self):
 '''
-The learning rate.
-
-Args:
-minibatch_size (int): minibatch size to re-scaled
-the learning rate to the per-sample value (in case when the schedule
-was build with ``unit=UnitType.minibatch``).
+Current learning rate.
 '''
-return super(Learner, self).learning_rate(minibatch_size)
+return super(Learner, self).learning_rate()

 @typemap
-def training_parameter_schedule(schedule, epoch_size=1, unit=UnitType.sample):
+def training_parameter_schedule(schedule, unit, epoch_size=1):
 '''
 Create a training parameter schedule containing either per-sample (default)
 or per-minibatch values.

 Examples:
 >>> # Use a fixed value 0.01 for all samples
->>> s = training_parameter_schedule(0.01)
+>>> s = training_parameter_schedule(0.01, UnitType.sample)
 >>> s[0], s[1]
 (0.01, 0.01)

 >>> # Use 0.01 for the first 1000 samples, then 0.001 for the remaining ones
->>> s = training_parameter_schedule([0.01, 0.001], 1000)
+>>> s = training_parameter_schedule([0.01, 0.001], UnitType.sample, 1000)
 >>> s[0], s[1], s[1000], s[1001]
 (0.01, 0.01, 0.001, 0.001)

 >>> # Use 0.1 for the first 12 epochs, then 0.01 for the next 15,
 >>> # followed by 0.001 for the remaining ones, with a 100 samples in an epoch
->>> s = training_parameter_schedule([(12, 0.1), (15, 0.01), (1, 0.001)], 100)
+>>> s = training_parameter_schedule([(12, 0.1), (15, 0.01), (1, 0.001)], UnitType.sample, 100)
 >>> s[0], s[1199], s[1200], s[2699], s[2700], s[5000]
 (0.1, 0.1, 0.01, 0.01, 0.001, 0.001)
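After the signature change above, Learner.learning_rate() takes no minibatch_size argument and simply reports the current value in whatever unit the schedule was built with. A short sketch, assuming the post-change API and a hypothetical model function z:

from cntk.learner import sgd, learning_rate_schedule, UnitType

lr = learning_rate_schedule([0.2, 0.1], UnitType.minibatch, 1000)
learner = sgd(z.parameters, lr=lr)
print(learner.learning_rate())   # 0.2 -- the per-minibatch value, no rescaling by minibatch size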
@@ -136,12 +157,11 @@ def training_parameter_schedule(schedule, epoch_size=1, unit=UnitType.sample):
 for all samples. In case of list, the elements are used as the
 values for ``epoch_size`` samples. If list contains pair, the second element is
 used as a value for (``epoch_size`` x first element) samples
+unit (:class:`UnitType`): one of two
+* ``sample``: the returned schedule contains per-sample values
+* ``minibatch``: the returned schedule contains per-minibatch values.
 epoch_size (int): number of samples as a scheduling unit. Parameters in
 the schedule change their values every ``epoch_size`` samples.
-unit (:class:`UnitType`): one of two
-
-* ``sample``: the returned schedule contains per-sample values (default)
-* ``minibatch``: the returned schedule contains per-minibatch values.

 Returns:
 training parameter schedule

@@ -153,10 +173,12 @@ def training_parameter_schedule(schedule, epoch_size=1, unit=UnitType.sample):
 raise ValueError('schedule unit "%s" is not supported' %
 str(method))

-if isinstance(schedule, (cntk_py.training_parameter_per_sample_schedule,
-cntk_py.training_parameter_per_minibatch_schedule,
-cntk_py.momentum_as_time_constant_schedule)):
-return schedule
+if unit == UnitType.sample:
+if isinstance(schedule, cntk_py.training_parameter_per_sample_schedule):
+return schedule
+else:
+if isinstance(schedule, cntk_py.training_parameter_per_minibatch_schedule):
+return schedule

 if isinstance(schedule, (int, float)):
 if unit is UnitType.sample:
@@ -173,7 +195,7 @@ def training_parameter_schedule(schedule, epoch_size=1, unit=UnitType.sample):
 raise ValueError('schedule must be either a float or a list, not %s'%type(schedule))

 @typemap
-def learning_rate_schedule(lr, epoch_size=1, unit=UnitType.sample):
+def learning_rate_schedule(lr, unit, epoch_size=1):
 '''
 Create a learning rate schedule (using the same semantics as
 :func:`training_parameter_schedule`).

@@ -181,10 +203,10 @@ def learning_rate_schedule(lr, epoch_size=1, unit=UnitType.sample):
 Args:
 lr (float or list): see parameter ``schedule`` in
 :func:`training_parameter_schedule`.
-epoch_size (int): see parameter ``epoch_size`` in
-:func:`training_parameter_schedule`.
 unit (:class:`UnitType`): see parameter
 ``unit`` in :func:`training_parameter_schedule`.
+epoch_size (int): see parameter ``epoch_size`` in
+:func:`training_parameter_schedule`.

 Returns:
 learning rate schedule
@@ -192,23 +214,21 @@ def learning_rate_schedule(lr, epoch_size=1, unit=UnitType.sample):
 See also:
 :func:`training_parameter_schedule`
 '''
-return training_parameter_schedule(lr, epoch_size, unit)
+return training_parameter_schedule(lr, unit, epoch_size)

 @typemap
-def momentum_schedule(momentum, epoch_size=1, unit=UnitType.sample):
+def momentum_schedule(momentum, epoch_size=1):
 '''
-Create a momentum schedule (using the same semantics as
-:func:`training_parameter_schedule`).
+Create a per-minibatch momentum schedule (using the same semantics as
+:func:`training_parameter_schedule` with the `unit=UnitType.minibatch`).

 Args:
 momentum (float or list): see parameter ``schedule`` in
 :func:`training_parameter_schedule`.
 epoch_size (int): see parameter ``epoch_size`` in
 :func:`training_parameter_schedule`.
-unit (:class:`UnitType`): see parameter
-``unit`` in :func:`training_parameter_schedule`.

-If you want to provide momentum values in a sample/minibatch
+If you want to provide momentum values in a minibatch-size
 agnostic way, use :func:`momentum_as_time_constant_schedule`.

 Examples:
@@ -228,32 +248,23 @@ def momentum_schedule(momentum, epoch_size=1, unit=UnitType.sample):
 >>> m[0], m[998], m[999], m[999+888-1], m[999+888]
 (0.99, 0.99, 0.88, 0.88, 0.77)

-Args:
-momentum (float or list): see parameter ``schedule`` in
-:func:`training_parameter_schedule`.
-epoch_size (int): see parameter ``epoch_size`` in
-:func:`training_parameter_schedule`.
-unit (:class:`UnitType`): see parameter
-``unit`` in :func:`training_parameter_schedule`.
-
 Returns:
 momentum schedule
 '''
-return training_parameter_schedule(momentum, epoch_size, unit)
+return training_parameter_schedule(momentum, UnitType.minibatch, epoch_size)

 @typemap
 def momentum_as_time_constant_schedule(momentum, epoch_size=1):
 '''
-Create a momentum schedule in a minibatch agnostic way (using the same
-semantics as :func:`training_parameter_schedule`).
+Create a momentum schedule in a minibatch-size agnostic way
+(using the same semantics as :func:`training_parameter_schedule`
+with `unit=UnitType.sample`).

 Args:
 momentum (float or list): see parameter ``schedule`` in
 :func:`training_parameter_schedule`.
 epoch_size (int): see parameter ``epoch_size`` in
 :func:`training_parameter_schedule`.
-unit (:class:`UnitType`): see parameter
-``unit`` in :func:`training_parameter_schedule`.

 CNTK specifies momentum in a minibatch-size agnostic way as the time
 constant (in samples) of a unit-gain 1st-order IIR filter. The value
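After this change momentum_schedule() always produces per-minibatch momentum values, while momentum_as_time_constant_schedule() stays minibatch-size agnostic; the unit argument is gone from both. A hedged sketch of the two forms under the post-change API:

from cntk.learner import momentum_schedule, momentum_as_time_constant_schedule

m_per_minibatch = momentum_schedule(0.9)                   # 0.9 applied per minibatch
m_time_constant = momentum_as_time_constant_schedule(1100) # time constant of 1100 samples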
@@ -263,7 +274,6 @@ def momentum_as_time_constant_schedule(momentum, epoch_size=1):
 If you want to specify the momentum per sample (or per minibatch),
 use :func:`momentum_schedule`.

-
 Examples:
 >>> # Use a fixed momentum of 1100 for all samples
 >>> m = momentum_as_time_constant_schedule(1100)
@@ -272,18 +282,10 @@ def momentum_as_time_constant_schedule(momentum, epoch_size=1):
 >>> # then 1500 for the remaining ones
 >>> m = momentum_as_time_constant_schedule([1100, 1500], 1000)

-Args:
-momentum (float or list): see parameter ``schedule`` in
-:func:`training_parameter_schedule`.
-epoch_size (int): see parameter ``epoch_size`` in
-:func:`training_parameter_schedule`.
-
 Returns:
 momentum as time constant schedule
 '''
-if isinstance(momentum, (cntk_py.training_parameter_per_sample_schedule,
-cntk_py.training_parameter_per_minibatch_schedule,
-cntk_py.momentum_as_time_constant_schedule)):
+if isinstance(momentum, (cntk_py.momentum_as_time_constant_schedule)):
 return momentum

 if isinstance(momentum, (int, float)):
@@ -293,7 +295,6 @@ def momentum_as_time_constant_schedule(momentum, epoch_size=1):

 raise ValueError('momentum must be either a float or a list, not %s'%type(momentum))

-
 # TODO figure out how to pass infty to C++ in a portable way
 @typemap
 def sgd(parameters, lr,
@@ -308,9 +309,7 @@ def sgd(parameters, lr,
 parameters (list of parameters): list of network parameters to tune.
 These can be obtained by the '.parameters()' method of the root
 operator.
-lr (float, list or output of :func:`learning_rate_schedule`): learning rate
-schedule. When the argument value is a float or a list, lr is
-converted to a per-sample schedule by invoking :func:`learning_rate_schedule`.
+lr (output of :func:`learning_rate_schedule`): learning rate schedule.
 l1_regularization_weight (float, optional): the L1 regularization weight per sample,
 defaults to 0.0
 l2_regularization_weight (float, optional): the L2 regularization weight per sample,

@@ -319,7 +318,8 @@ def sgd(parameters, lr,
 of the Gaussian noise added to parameters post update, defaults to 0.0
 gradient_clipping_threshold_per_sample (float, optional): clipping threshold
 per sample, defaults to infinity
-gradient_clipping_with_truncation (bool, default ``True``): gradient clipping
+gradient_clipping_with_truncation (bool, default ``True``): use gradient clipping
 with truncation
+
 Returns:
 Instance of a :class:`~cntk.learner.Learner` that can be passed to the :class:`~cntk.trainer.Trainer`
@@ -329,8 +329,9 @@ def sgd(parameters, lr,
 <http://research.microsoft.com/pubs/192769/tricks-2012.pdf>`_. Neural
 Networks: Tricks of the Trade: Springer, 2012.
 '''
-lr = learning_rate_schedule(lr)
-gaussian_noise_injection_std_dev = training_parameter_schedule(gaussian_noise_injection_std_dev)
+_verify_learning_rate_type(lr)
+gaussian_noise_injection_std_dev = \
+training_parameter_schedule(gaussian_noise_injection_std_dev, UnitType.minibatch)

 additional_options = cntk_py.AdditionalLearningOptions()
 additional_options.l1_regularization_weight = l1_regularization_weight
@@ -347,17 +348,15 @@ def momentum_sgd(parameters, lr, momentum,
 gaussian_noise_injection_std_dev=0.0, gradient_clipping_threshold_per_sample=1E10,
 gradient_clipping_with_truncation=True):
 '''
-Creates a Momemtum SGD learner instance to learn the parameters.
+Creates a Momentum SGD learner instance to learn the parameters.

 Args:
 parameters (list of parameters): list of network parameters to tune.
 These can be obtained by the root operator's ``parameters``.
-lr (float, list```` or output of :func:`learning_rate_schedule`): learning rate
-schedule. When the argument value is a float or a list, lr is
-converted to a per-sample schedule by invoking :func:`learning_rate_schedule`.
-momentum (float, list or output of :func:`momentum_schedule` or :func:`momentum_as_time_constant_schedule`): momentum schedule. When the argument
-value is a float or a list, momentum is converted to a per-sample schedule by
-invoking :func:`momentum_schedule`. Refer to the `wiki
+lr (output of :func:`learning_rate_schedule`): learning rate schedule.
+momentum (output of :func:`momentum_schedule` or
+:func:`momentum_as_time_constant_schedule`): momentum schedule.
+For additional information, please refer to the `wiki
 <https://github.com/Microsoft/CNTK/wiki/SGD-block#converting-learning-rate-and-momentum-parameters-from-other-toolkits>`_.
 l1_regularization_weight (float, optional): the L1 regularization weight per sample,
 defaults to 0.0
@@ -367,14 +366,16 @@ def momentum_sgd(parameters, lr, momentum,
 of the Gaussian noise added to parameters post update, defaults to 0.0
 gradient_clipping_threshold_per_sample (float, optional): clipping threshold
 per sample, defaults to infinity
-gradient_clipping_with_truncation (bool, default ``True``): gradient clipping
+gradient_clipping_with_truncation (bool, default ``True``): use gradient clipping
 with truncation
+
 Returns:
-Instance of a :class:`~cntk.learner.Learner` that can be passed to the :class:`~cntk.trainer.Trainer`
+Instance of a :class:`cntk.learner.Learner` that can be passed to the :class:`cntk.trainer.Trainer`
 '''
-lr = learning_rate_schedule(lr)
-momentum = momentum_schedule(momentum)
-gaussian_noise_injection_std_dev = training_parameter_schedule(gaussian_noise_injection_std_dev)
+_verify_learning_rate_type(lr)
+_verify_momentum_type(momentum)
+gaussian_noise_injection_std_dev = \
+training_parameter_schedule(gaussian_noise_injection_std_dev, UnitType.minibatch)

 additional_options = cntk_py.AdditionalLearningOptions()
 additional_options.l1_regularization_weight = l1_regularization_weight
@@ -399,12 +400,10 @@ def nesterov(parameters, lr, momentum,
 Args:
 parameters (list of parameters): list of network parameters to tune.
 These can be obtained by the root operator's ``parameters``.
-lr (float, list or output of :func:`learning_rate_schedule`): learning rate
-schedule. When the argument value is a float or a list, lr is
-converted to a per-sample schedule by invoking :func:`learning_rate_schedule`.
-momentum (float, list or output of :func:`momentum_schedule` or :func:`momentum_as_time_constant_schedule`): momentum schedule. When the argument
-value is a float or a list, momentum is converted to a per-sample schedule by
-invoking :func:`momentum_schedule`. Refer to the `wiki
+lr (output of :func:`learning_rate_schedule`): learning rate schedule.
+momentum (output of :func:`momentum_schedule` or
+:func:`momentum_as_time_constant_schedule`): momentum schedule.
+For additional information, please refer to the `wiki
 <https://github.com/Microsoft/CNTK/wiki/SGD-block#converting-learning-rate-and-momentum-parameters-from-other-toolkits>`_.
 l1_regularization_weight (float, optional): the L1 regularization weight per sample,
 defaults to 0.0
@@ -414,7 +413,8 @@ def nesterov(parameters, lr, momentum,
 of the Gaussian noise added to parameters post update, defaults to 0.0
 gradient_clipping_threshold_per_sample (float, optional): clipping threshold
 per sample, defaults to infinity
-gradient_clipping_with_truncation (bool, default ``True``): gradient clipping
+gradient_clipping_with_truncation (bool, default ``True``): use gradient clipping
 with truncation
+
 Returns:
 Instance of a :class:`~cntk.learner.Learner` that can be passed to the
@@ -429,9 +429,10 @@ def nesterov(parameters, lr, momentum,
 of the 30th International Conference on Machine Learning, 2013.

 '''
-lr = learning_rate_schedule(lr)
-momentum = momentum_schedule(momentum)
-gaussian_noise_injection_std_dev = training_parameter_schedule(gaussian_noise_injection_std_dev)
+_verify_learning_rate_type(lr)
+_verify_momentum_type(momentum)
+gaussian_noise_injection_std_dev = \
+training_parameter_schedule(gaussian_noise_injection_std_dev, UnitType.minibatch)

 additional_options = cntk_py.AdditionalLearningOptions()
 additional_options.l1_regularization_weight = l1_regularization_weight
@@ -455,9 +456,7 @@ def adagrad(parameters, lr, need_ave_multiplier=True,
 Args:
 parameters (list of parameters): list of network parameters to tune.
 These can be obtained by the root operator's ``parameters``.
-lr (float, list or output of :func:`learning_rate_schedule`): learning rate
-schedule. When the argument value is a float or a list, lr is
-converted to a per-sample schedule by invoking :func:`learning_rate_schedule`.
+lr (output of :func:`learning_rate_schedule`): learning rate schedule.
 need_ave_multiplier (bool, default):
 l1_regularization_weight (float, optional): the L1 regularization weight per sample,
 defaults to 0.0
@@ -467,7 +466,8 @@ def adagrad(parameters, lr, need_ave_multiplier=True,
 of the Gaussian noise added to parameters post update, defaults to 0.0
 gradient_clipping_threshold_per_sample (float, optional): clipping threshold
 per sample, defaults to infinity
-gradient_clipping_with_truncation (bool, default `True`): gradient clipping
+gradient_clipping_with_truncation (bool, default ``True``): use gradient clipping
 with truncation
+
 Returns:
 Instance of a :class:`~cntk.learner.Learner` that can be passed to the :class:`~cntk.trainer.Trainer`
|
@ -478,8 +478,9 @@ def adagrad(parameters, lr, need_ave_multiplier=True,
|
|||
<http://www.magicbroom.info/Papers/DuchiHaSi10.pdf>`_. The Journal of
|
||||
Machine Learning Research, 2011.
|
||||
'''
|
||||
lr = learning_rate_schedule(lr)
|
||||
gaussian_noise_injection_std_dev = training_parameter_schedule(gaussian_noise_injection_std_dev)
|
||||
_verify_learning_rate_type(lr)
|
||||
gaussian_noise_injection_std_dev = \
|
||||
training_parameter_schedule(gaussian_noise_injection_std_dev, UnitType.minibatch)
|
||||
|
||||
additional_options = cntk_py.AdditionalLearningOptions()
|
||||
additional_options.l1_regularization_weight = l1_regularization_weight
|
||||
|
@ -506,16 +507,14 @@ def adam_sgd(parameters, lr, momentum,
|
|||
Args:
|
||||
parameters (list of parameters): list of network parameters to tune.
|
||||
These can be obtained by the root operator's ``parameters``.
|
||||
lr (float, list or output of :func:`learning_rate_schedule`): learning rate
|
||||
schedule. When the argument value is a float or a list, lr is
|
||||
converted to a per-sample schedule by invoking :func:`learning_rate_schedule`.
|
||||
momentum (float, list or output of :func:`momentum_schedule` or :func:`momentum_as_time_constant_schedule`): momentum schedule. When the argument
|
||||
value is a float or a list, momentum is converted to a per-sample schedule by
|
||||
invoking :func:`momentum_schedule`. Refer to the `wiki
|
||||
lr (output of :func:`learning_rate_schedule`): learning rate schedule.
|
||||
momentum (output of :func:`momentum_schedule` or
|
||||
:func:`momentum_as_time_constant_schedule`): momentum schedule.
|
||||
For additional information, please refer to the `wiki
|
||||
<https://github.com/Microsoft/CNTK/wiki/SGD-block#converting-learning-rate-and-momentum-parameters-from-other-toolkits>`_.
|
||||
variance_momentum (float, list or output of :func:`momentum_schedule` or :func:`momentum_as_time_constant_schedule`): variance momentum schedule. When the argument
|
||||
value is a float or a list, variance momentum is converted to a per-sample schedule by
|
||||
invoking :func:`momentum_schedule`. Defaults to momentum_as_time_constant_schedule(720000).
|
||||
variance_momentum (output of :func:`momentum_schedule` or
|
||||
:func:`momentum_as_time_constant_schedule`): variance momentum schedule. Defaults
|
||||
to ``momentum_as_time_constant_schedule(720000)``.
|
||||
l1_regularization_weight (float, optional): the L1 regularization weight per sample,
|
||||
defaults to 0.0
|
||||
l2_regularization_weight (float, optional): the L2 regularization weight per sample,
|
||||
|
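
Under the tightened adam_sgd signature the momentum argument is also expected to be a schedule object; a time-constant schedule is accepted, as the updated docstring states. A short sketch with illustrative values (z is again a placeholder model, and variance_momentum is left at its documented default of momentum_as_time_constant_schedule(720000)):

from cntk.learner import adam_sgd, learning_rate_schedule, momentum_as_time_constant_schedule, UnitType

# z is a placeholder model; values are illustrative only
lr_per_sample = learning_rate_schedule(0.0005, UnitType.sample)
momentum_time_constant = momentum_as_time_constant_schedule(700)
learner = adam_sgd(z.parameters, lr=lr_per_sample, momentum=momentum_time_constant)
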
@ -524,7 +523,8 @@ def adam_sgd(parameters, lr, momentum,
of the Gaussian noise added to parameters post update, defaults to 0.0
gradient_clipping_threshold_per_sample (float, optional): clipping threshold
per sample, defaults to infinity
gradient_clipping_with_truncation (bool, default `True`): gradient clipping
gradient_clipping_with_truncation (bool, default ``True``): use gradient clipping
with truncation

Returns:
Instance of a :class:`~cntk.learner.Learner` that can be passed to the :class:`~cntk.trainer.Trainer`

@ -537,10 +537,11 @@ def adam_sgd(parameters, lr, momentum,
if not low_memory:
raise NotImplementedError('adam: low_memory=True currently required')

lr = learning_rate_schedule(lr)
momentum = momentum_schedule(momentum)
variance_momentum = momentum_schedule(variance_momentum)
gaussian_noise_injection_std_dev = training_parameter_schedule(gaussian_noise_injection_std_dev)
_verify_learning_rate_type(lr)
_verify_momentum_type(momentum)
_verify_momentum_type(variance_momentum)
gaussian_noise_injection_std_dev = \
training_parameter_schedule(gaussian_noise_injection_std_dev, UnitType.minibatch)

additional_options = cntk_py.AdditionalLearningOptions()
additional_options.l1_regularization_weight = l1_regularization_weight

@ -565,15 +566,13 @@ def rmsprop(parameters, lr,
Args:
parameters (list of parameters): list of network parameters to tune.
These can be obtained by the root operator's ``parameters``.
lr (float, list or output of :func:`learning_rate_schedule`): learning rate
schedule. When the argument value is a float or a list, lr is
converted to a per-sample schedule by invoking :func:`learning_rate_schedule`.
lr (output of :func:`learning_rate_schedule`): learning rate schedule.
gamma (float):
inc (float):
dec (float):
max (float):
min (float):
need_ave_multiplier (bool, default):
need_ave_multiplier (bool, default ``True``):
l1_regularization_weight (float, optional): the L1 regularization weight per sample,
defaults to 0.0
l2_regularization_weight (float, optional): the L2 regularization weight per sample,

@ -582,13 +581,15 @@ def rmsprop(parameters, lr,
of the Gaussian noise added to parameters post update, defaults to 0.0
gradient_clipping_threshold_per_sample (float, optional): clipping threshold
per sample, defaults to infinity
gradient_clipping_with_truncation (bool, default `True`): gradient clipping
gradient_clipping_with_truncation (bool, default ``True``): use gradient clipping
with truncation

Returns:
Instance of a :class:`~cntk.learner.Learner` that can be passed to the :class:`~cntk.trainer.Trainer`
'''
lr = learning_rate_schedule(lr)
gaussian_noise_injection_std_dev = training_parameter_schedule(gaussian_noise_injection_std_dev)
_verify_learning_rate_type(lr)
gaussian_noise_injection_std_dev = \
training_parameter_schedule(gaussian_noise_injection_std_dev, UnitType.minibatch)

additional_options = cntk_py.AdditionalLearningOptions()
additional_options.l1_regularization_weight = l1_regularization_weight
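
rmsprop follows the same pattern. A sketch of the updated call, mirroring the positional argument order used in test_learner_init further down; the gamma/inc/dec/max/min values here are purely illustrative, not recommended settings, and z is a placeholder model:

from cntk.learner import rmsprop, learning_rate_schedule, UnitType

# 0.1 per sample for the first 100 samples, then 0.01 from there on
lr_per_sample = learning_rate_schedule([0.1, 0.01], UnitType.sample, 100)
learner = rmsprop(z.parameters, lr_per_sample, 0.95, 1.2, 0.7, 10.0, 0.1,
                  need_ave_multiplier=True)
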
@ -40,9 +40,9 @@ def run_distributed_trainer(tmpdir, quantized):
dist_trainer = distributed.data_parallel_distributed_trainer(communicator, False)

momentum_time_constant = momentum_as_time_constant_schedule(1100)

lr_per_sample = learning_rate_schedule(0.007, UnitType.sample)
trainer = Trainer(z, ce, errs, \
momentum_sgd(z.parameters, 0.007, momentum_time_constant),
momentum_sgd(z.parameters, lr_per_sample, momentum_time_constant),
distributed_trainer=dist_trainer)
in1_value = [[1],[2]]
label_value = [[0], [1]]

@ -10,21 +10,29 @@ from .. import parameter, input_variable

import pytest

SCHEDULE_PARAMS = [
LR_SCHEDULE_PARAMS = [
((0.2, UnitType.sample), [0.2]),
((0.2, UnitType.sample), [0.2, 0.2, 0.2, 0.2]),
(([0.2,0.4], UnitType.sample, 5), [0.2]*5+[0.4]*20),
(([(3,0.2),(2,0.4),(1,0.8)], UnitType.sample, 5), [0.2]*15+[0.4]*10+[0.8]*20),
]

MOMENTUM_SCHEDULE_PARAMS = [
((0.2,), [0.2]),
((0.2,), [0.2, 0.2, 0.2, 0.2]),
(([0.2,0.4], 5), [0.2]*5+[0.4]*20),
(([(3,0.2),(2,0.4),(1,0.8)], 5), [0.2]*15+[0.4]*10+[0.8]*20),
]
@pytest.mark.parametrize("params, expectation", SCHEDULE_PARAMS)

@pytest.mark.parametrize("params, expectation", LR_SCHEDULE_PARAMS)
def test_learning_rate_schedule(params, expectation):
l = learning_rate_schedule(*params)
assert [l[i] for i in range(len(expectation))] == expectation

def sweep_based_schedule_fails():
with pytest.raises(Exception):
learning_rate_schedule([1], epoch_size=0)

learning_rate_schedule([1], unit=UnitType.sample, epoch_size=0)

def test_momentum_schedule():
m = 2500
ms = momentum_as_time_constant_schedule([m])
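
The split into LR_SCHEDULE_PARAMS and MOMENTUM_SCHEDULE_PARAMS reflects that learning-rate schedules now take a UnitType while momentum schedules do not. The expectations also spell out the schedule semantics: each list entry holds for epoch_size units before the schedule advances, a (count, value) pair repeats value for count epochs, and the last value then keeps applying. A tiny illustration of that indexing, using the same parameters as the last test case:

from cntk.learner import learning_rate_schedule, UnitType

lr = learning_rate_schedule([(3, 0.2), (2, 0.4), (1, 0.8)], UnitType.sample, 5)
assert lr[0] == 0.2     # samples 0-14: three epochs of 5 samples at 0.2
assert lr[15] == 0.4    # samples 15-24: two epochs at 0.4
assert lr[100] == 0.8   # last value keeps applying afterwards
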
@ -38,7 +46,7 @@ def test_momentum_schedule():
expected = np.exp(-1.0 / np.asarray(mlist))
assert all(mi == ei for mi,ei in zip(msl,expected))

@pytest.mark.parametrize("params, expectation", SCHEDULE_PARAMS)
@pytest.mark.parametrize("params, expectation", MOMENTUM_SCHEDULE_PARAMS)
def test_momentum_schedule_per_sample(params, expectation):
l = momentum_schedule(*params)
assert [l[i] for i in range(len(expectation))] == expectation
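
The expected value in this test encodes the conversion between a momentum time constant and the equivalent per-sample momentum, momentum = exp(-1 / time_constant). The same relation in isolation (2500 is just an example time constant):

import math
from cntk.learner import momentum_as_time_constant_schedule

time_constant = 2500
ms = momentum_as_time_constant_schedule(time_constant)
assert abs(ms[0] - math.exp(-1.0 / time_constant)) < 1e-12
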
@ -51,16 +59,11 @@ def test_learner_init():

res = i * w

learner = sgd(res.parameters, lr=learning_rate_schedule(0.1, 10000))

#per-sample learning rate does not depend on the minibatch size
learner = sgd(res.parameters, lr=learning_rate_schedule(0.1, UnitType.sample, 10000))
assert learner.learning_rate() == 0.1
assert learner.learning_rate(0) == 0.1
assert learner.learning_rate(10) == 0.1

learner.reset_learning_rate(learning_rate_schedule([1,2,3], unit=UnitType.minibatch));
assert learner.learning_rate(100) == 0.01
assert learner.learning_rate(0) == 0.0
learner.reset_learning_rate(learning_rate_schedule([1,2,3], UnitType.minibatch));
assert learner.learning_rate() == 1.0

learner_parameter = learner.parameters
from ..ops.variables import Parameter

@ -68,18 +71,21 @@ def test_learner_init():
assert isinstance(param, Parameter)

momentum_time_constant = momentum_as_time_constant_schedule(1100)
momentum_sgd(res.parameters, 0.1, momentum_time_constant)
lr_per_sample = learning_rate_schedule(0.1, UnitType.sample)
momentum_sgd(res.parameters, lr_per_sample, momentum_time_constant)

momentum_time_constant = momentum_schedule(momentum_time_constant) #should be ignored
nesterov(res.parameters, lr=[0.1, 0.2], momentum=momentum_time_constant)
lr_per_sample = learning_rate_schedule([0.1, 0.2], UnitType.sample)
nesterov(res.parameters, lr=lr_per_sample, momentum=momentum_time_constant)

adagrad(res.parameters, lr=[0.1]*3 +[0.2]*2 +[0.3], need_ave_multiplier=True)
lr_per_sample = learning_rate_schedule([0.1]*3 +[0.2]*2 +[0.3], UnitType.sample)
adagrad(res.parameters, lr=lr_per_sample, need_ave_multiplier=True)

momentum_time_constant = momentum_schedule(momentum_time_constant, unit=UnitType.minibatch) #should be ignored
adam_sgd(res.parameters, lr=[(3,0.1), (2, 0.2), (1, 0.3)], momentum=momentum_time_constant)
lr_per_sample = learning_rate_schedule([(3,0.1), (2, 0.2), (1, 0.3)], UnitType.sample)
adam_sgd(res.parameters, lr=lr_per_sample, momentum=momentum_time_constant)

gamma, inc, dec, max, min = [0.1]*5
rmsprop(res.parameters, learning_rate_schedule([0.1, 0.2], 100), gamma, inc, dec, max, min, True)
lr_per_sample = learning_rate_schedule([0.1, 0.2], UnitType.sample, 100)
rmsprop(res.parameters, lr_per_sample, gamma, inc, dec, max, min, True)

def test_learner_update():
i = input_variable(shape=(1,),

@ -89,7 +95,7 @@ def test_learner_update():
w = parameter(shape=(1,), init=w_init)
res = i * w

learner = sgd(res.parameters, lr=[0.1]*50 + [0.2]*50)
learner = sgd(res.parameters, lr=learning_rate_schedule([0.1]*50 + [0.2]*50, UnitType.sample))
assert learner.learning_rate() == 0.1
x = learner.update({w: np.asarray([[2.]], dtype=np.float32)}, 100)
assert learner.learning_rate() == 0.2

@ -20,10 +20,10 @@ def test_trainer(tmpdir):
ce = cross_entropy_with_softmax(z, labels)
errs = classification_error(z, labels)

m_schedule = momentum_schedule(1100)

momentum_time_constant = momentum_as_time_constant_schedule(1100)
lr_per_sample = learning_rate_schedule(0.007, UnitType.sample)
trainer = Trainer(z, ce, errs, \
[momentum_sgd(z.parameters, 0.007, m_schedule)])
[momentum_sgd(z.parameters, lr_per_sample, momentum_time_constant)])
in1_value = [[1],[2]]
label_value = [[0], [1]]
arguments = {in1: in1_value, labels: label_value}

@ -49,10 +49,10 @@ def test_output_to_retain():
ce = cross_entropy_with_softmax(z, labels)
errs = classification_error(z, labels)

m_schedule = momentum_schedule(1100)

momentum_time_constant = momentum_as_time_constant_schedule(1100)
lr_per_sample = learning_rate_schedule(0.007, UnitType.sample)
trainer = Trainer(z, ce, errs, \
[momentum_sgd(z.parameters, 0.007, m_schedule)])
[momentum_sgd(z.parameters, lr_per_sample, momentum_time_constant)])
in1_value = [[1],[2]]
label_value = [[0], [1]]
arguments = {in1: in1_value, labels: label_value}