Add a default (true) for the unit-gain flag value

Alexey Reznichenko 2017-01-19 11:41:54 +01:00
Parent cc7edb6d11
Commit 9b6c6bde34
19 changed files with 88 additions and 32 deletions
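
In short: before this change, the momentum-based learner factories (momentum_sgd, nesterov, and adam_sgd in Python; MomentumSGDLearner, NesterovLearner, and AdamLearner in C++) required the unit-gain flag to be passed explicitly. After it, the flag defaults to a process-wide value that ships as True and can be queried or changed through default_unit_gain_value() / set_default_unit_gain_value() (DefaultUnitGainValue() / SetDefaultUnitGainValue() on the C++ side). The sketch below illustrates the resulting Python usage; the tiny model, the schedule values, and the exact import paths are assumptions for illustration, not part of this commit.

from cntk.ops import input_variable, parameter, times
from cntk.learner import (momentum_sgd, learning_rate_schedule, UnitType,
                          momentum_as_time_constant_schedule,
                          default_unit_gain_value, set_default_unit_gain_value)

# Hypothetical one-layer model, only here so that z.parameters is non-empty.
x = input_variable(4)
z = times(x, parameter((4, 2)))

lr_schedule = learning_rate_schedule(0.01, UnitType.sample)
mm_schedule = momentum_as_time_constant_schedule(1100)

# Previously the flag had to be given explicitly:
#   learner = momentum_sgd(z.parameters, lr_schedule, mm_schedule, unit_gain=True)
# Now it may be omitted and falls back to the global default:
learner = momentum_sgd(z.parameters, lr_schedule, mm_schedule)

# The global default ships as True and can be flipped at runtime; this mirrors
# the new C++ DefaultUnitGainValue() / SetDefaultUnitGainValue() pair.
assert default_unit_gain_value()
set_default_unit_gain_value(False)
assert not default_unit_gain_value()
set_default_unit_gain_value(True)   # restore the shipped default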

View file

@@ -73,8 +73,7 @@ def convnet_cifar10(debug_output=False):
l2_reg_weight = 0.002
# Instantiate the trainer object to drive the model training
learner = cntk.learner.momentum_sgd(z.parameters, lr_schedule, mm_schedule,
unit_gain = True,
learner = cntk.learner.momentum_sgd(z.parameters, lr_schedule, mm_schedule,
l2_regularization_weight = l2_reg_weight)
trainer = cntk.Trainer(z, ce, pe, learner)

View file

@@ -85,7 +85,6 @@ def convnet_cifar10_dataaug(reader_train, reader_test, max_epochs = 80):
# trainer object
learner = cntk.learner.momentum_sgd(z.parameters, lr_schedule, mm_schedule,
unit_gain = True,
l2_regularization_weight = l2_reg_weight)
trainer = cntk.Trainer(z, ce, pe, learner)

View file

@@ -103,7 +103,7 @@ def create_trainer(network, epoch_size, num_quantization_bits):
# Create learner
learner = data_parallel_distributed_learner(
cntk.learner.momentum_sgd(network['output'].parameters, lr_schedule, mm_schedule, unit_gain=True, l2_regularization_weight=l2_reg_weight),
cntk.learner.momentum_sgd(network['output'].parameters, lr_schedule, mm_schedule, l2_regularization_weight=l2_reg_weight),
num_quantization_bits=num_quantization_bits,
distributed_after=0)

View file

@@ -64,7 +64,7 @@ def convnet_mnist(debug_output=False):
mm_schedule = cntk.learner.momentum_as_time_constant_schedule(mm_time_constant, epoch_size)
# Instantiate the trainer object to drive the model training
learner = cntk.learner.momentum_sgd(z.parameters, lr_schedule, mm_schedule, unit_gain=True)
learner = cntk.learner.momentum_sgd(z.parameters, lr_schedule, mm_schedule)
trainer = cntk.Trainer(z, ce, pe, learner)
# define mapping from reader streams to network inputs

View file

@@ -88,7 +88,6 @@ def train_and_evaluate(reader_train, reader_test, network_name, max_epochs):
# trainer object
learner = momentum_sgd(z.parameters, lr_schedule, mm_schedule,
unit_gain = True,
l2_regularization_weight = l2_reg_weight)
trainer = Trainer(z, ce, pe, learner)

View file

@@ -83,7 +83,7 @@ def create_trainer(network, minibatch_size, epoch_size, num_quantization_bits):
# learner object
local_learner = momentum_sgd(network['output'].parameters, lr_schedule, mm_schedule,
unit_gain = True, l2_regularization_weight = l2_reg_weight)
l2_regularization_weight = l2_reg_weight)
learner = data_parallel_distributed_learner(learner=local_learner,
num_quantization_bits=num_quantization_bits,

View file

@@ -83,7 +83,6 @@ def train(reader, model, max_epochs):
lr_per_sample = learning_rate_schedule(lr_schedule, UnitType.sample, epoch_size)
learner = adam_sgd(z.parameters,
lr=lr_per_sample, momentum=momentum_time_constant,
unit_gain=True,
low_memory=True,
gradient_clipping_threshold_per_sample=15, gradient_clipping_with_truncation=True)

View file

@@ -158,7 +158,6 @@ def sequence_to_sequence_translator(debug_output=False, run_test=False):
gradient_clipping_with_truncation = True
learner = momentum_sgd(z.parameters,
lr_per_minibatch, momentum_time_constant,
unit_gain=True,
gradient_clipping_threshold_per_sample=clipping_threshold_per_sample,
gradient_clipping_with_truncation=gradient_clipping_with_truncation)
trainer = Trainer(z, ce, errs, learner)

View file

@@ -163,8 +163,7 @@ def train_lm(training_file):
momentum_time_constant = momentum_as_time_constant_schedule(1100)
clipping_threshold_per_sample = 5.0
gradient_clipping_with_truncation = True
learner = momentum_sgd(z.parameters, lr_per_sample, momentum_time_constant,
unit_gain=True,
learner = momentum_sgd(z.parameters, lr_per_sample, momentum_time_constant,
gradient_clipping_threshold_per_sample=clipping_threshold_per_sample,
gradient_clipping_with_truncation=gradient_clipping_with_truncation)
trainer = Trainer(z, ce, errs, learner)

View file

@@ -3435,6 +3435,16 @@ namespace CNTK
bool gradientClippingWithTruncation = true;
};
///
/// Returns true if by default momentum is applied in the unit-gain fashion.
///
CNTK_API bool DefaultUnitGainValue();
///
/// Sets globally default unit-gain flag value.
///
CNTK_API void SetDefaultUnitGainValue(bool value);
///
/// Abstraction for learning a subset of parameters of a learnable Function using first order gradient values
/// For e.g momentum, AdaGrad, RMSProp etc. are different types of learners with their own algorithms for
@@ -3545,7 +3555,7 @@ namespace CNTK
CNTK_API LearnerPtr MomentumSGDLearner(const std::vector<Parameter>& parameters,
const LearningRateSchedule& learningRateSchedule,
const MomentumSchedule& momentumSchedule,
bool unitGain,
bool unitGain = DefaultUnitGainValue(),
AdditionalLearningOptions additionalOptions = AdditionalLearningOptions());
///
@@ -3554,7 +3564,7 @@ namespace CNTK
CNTK_API LearnerPtr NesterovLearner(const std::vector<Parameter>& parameters,
const LearningRateSchedule& learningRateSchedule,
const MomentumSchedule& momentumSchedule,
bool unitGain,
bool unitGain = DefaultUnitGainValue(),
AdditionalLearningOptions additionalOptions = AdditionalLearningOptions());
static MomentumSchedule DefaultVarianceMomentum = MomentumAsTimeConstantSchedule(2 * 3600 * 100);
@@ -3565,7 +3575,7 @@ namespace CNTK
CNTK_API LearnerPtr AdamLearner(const std::vector<Parameter>& parameters,
const LearningRateSchedule& learningRateSchedule,
const MomentumSchedule& momentumSchedule,
bool unitGain,
bool unitGain = DefaultUnitGainValue(),
const MomentumSchedule& varianceMomentumSchedule = DefaultVarianceMomentum,
bool lowMemory = true,
AdditionalLearningOptions additionalOptions = AdditionalLearningOptions());

View file

@@ -528,4 +528,16 @@ namespace CNTK
{
return Microsoft::MSR::CNTK::CPUMatrix<float>::GetMaxNumThreads();
}
static std::atomic<bool> s_defaultUnitGainValue(true);
bool DefaultUnitGainValue()
{
return s_defaultUnitGainValue;
}
void SetDefaultUnitGainValue(bool value)
{
s_defaultUnitGainValue.store(value);
}
}

View file

@@ -228,11 +228,23 @@ void TestTrainingParametersSchedule()
assert(schedule16[99999] == exp(-1.0 / 3.0));
}
void TestDefaultUnitGainGetterAndSetter()
{
assert(DefaultUnitGainValue());
SetDefaultUnitGainValue(false);
assert(!DefaultUnitGainValue());
SetDefaultUnitGainValue(true);
assert(DefaultUnitGainValue());
}
void LearnerTests()
{
fprintf(stderr, "\nLearnerTests..\n");
TestDefaultUnitGainGetterAndSetter();
TestTrainingParametersSchedule();
vector<DeviceDescriptor> devices{DeviceDescriptor::CPUDevice()};

View file

@@ -271,7 +271,7 @@
" # Feel free to try other optimizers from \n",
" # https://www.cntk.ai/pythondocs/cntk.learner.html#module-cntk.learner\n",
" learner = adam_sgd(model.parameters,\n",
" lr=lr_schedule, momentum=momentum_as_time_constant, unit_gain=True) \n",
" lr=lr_schedule, momentum=momentum_as_time_constant) \n",
" \n",
" # Instantiate the trainer\n",
" trainer = Trainer(model, loss, label_error, learner)\n",

View file

@@ -324,7 +324,6 @@
" # trainer object\n",
" learner = momentum_sgd(z.parameters, \n",
" lr = lr_per_minibatch, momentum = momentum_time_constant, \n",
" unit_gain = True, \n",
" l2_regularization_weight=l2_reg_weight)\n",
" trainer = Trainer(z, ce, pe, [learner])\n",
"\n",

View file

@@ -392,7 +392,6 @@
" # https://www.cntk.ai/pythondocs/cntk.learner.html#module-cntk.learner\n",
" learner = adam_sgd(criterion.parameters,\n",
" lr=lr_schedule, momentum=momentum_as_time_constant,\n",
" unit_gain = True, \n",
" low_memory=True,\n",
" gradient_clipping_threshold_per_sample=15, gradient_clipping_with_truncation=True)\n",
"\n",
@@ -498,7 +497,7 @@
" lr_schedule = learning_rate_schedule(1, UnitType.minibatch)\n",
" momentum_as_time_constant = momentum_as_time_constant_schedule(0)\n",
" dummy_learner = adam_sgd(criterion.parameters, \n",
" lr=lr_schedule, momentum=momentum_as_time_constant, unit_gain=True, low_memory=True)\n",
" lr=lr_schedule, momentum=momentum_as_time_constant, low_memory=True)\n",
" evaluator = Trainer(model, criterion.outputs[0], criterion.outputs[1], dummy_learner)\n",
" progress_printer = ProgressPrinter(tag='Evaluation')\n",
"\n",

View file

@@ -753,7 +753,6 @@
"gradient_clipping_with_truncation = True\n",
"learner = momentum_sgd(model.parameters,\n",
" lr_per_sample, momentum_time_constant,\n",
" unit_gain = True, \n",
" gradient_clipping_threshold_per_sample=clipping_threshold_per_sample,\n",
" gradient_clipping_with_truncation=gradient_clipping_with_truncation)\n",
"trainer = Trainer(model, ce, errs, learner)"
@@ -920,7 +919,6 @@
" gradient_clipping_with_truncation = True\n",
" learner = momentum_sgd(model.parameters,\n",
" lr_per_sample, momentum_time_constant,\n",
" unit_gain = True, \n",
" gradient_clipping_threshold_per_sample=clipping_threshold_per_sample,\n",
" gradient_clipping_with_truncation=gradient_clipping_with_truncation)\n",
" trainer = Trainer(model, ce, errs, learner)\n",

View file

@@ -309,18 +309,18 @@
%ignore CNTK::MomentumSGDLearner(const std::vector<Parameter>& parameters,
const LearningRateSchedule& learningRateSchedule,
const MomentumSchedule& momentumSchedule,
bool unitGain,
bool unitGain = DefaultUnitGainValue(),
AdditionalLearningOptions additionalOptions = AdditionalLearningOptions());
%ignore CNTK::NesterovLearner(const std::vector<Parameter>& parameters,
const LearningRateSchedule& learningRateSchedule,
const MomentumSchedule& momentumSchedule,
bool unitGain,
bool unitGain = DefaultUnitGainValue(),
AdditionalLearningOptions additionalOptions = AdditionalLearningOptions());
%ignore CNTK::DefaultVarianceMomentum;
%ignore CNTK::AdamLearner(const std::vector<Parameter>& parameters,
const LearningRateSchedule& learningRateSchedule,
const MomentumSchedule& momentumSchedule,
bool unitGain,
bool unitGain = DefaultUnitGainValue(),
const MomentumSchedule& varianceMomentumSchedule = DefaultVarianceMomentum,
bool lowMemory = true,
AdditionalLearningOptions additionalOptions = AdditionalLearningOptions());

View file

@@ -51,6 +51,18 @@ the following learning algorithms:
+------------------------+
'''
def default_unit_gain_value():
'''
Returns true if by default momentum is applied in the unit-gain fashion.
'''
return cntk_py.default_unit_gain_value()
def set_default_unit_gain_value(value):
'''
Sets globally default unit-gain flag value.
'''
cntk_py.set_default_unit_gain_value(value)
# an internal method to verify that the learning rate schedule
# has a proper (per-sample or per-MB schedule) type and raise
# an exception otherwise
@@ -342,7 +354,7 @@ def sgd(parameters, lr,
return cntk_py.sgd_learner(parameters, lr, additional_options)
@typemap
def momentum_sgd(parameters, lr, momentum, unit_gain,
def momentum_sgd(parameters, lr, momentum, unit_gain=default_unit_gain_value(),
l1_regularization_weight=0.0, l2_regularization_weight=0.0,
gaussian_noise_injection_std_dev=0.0, gradient_clipping_threshold_per_sample=np.inf,
gradient_clipping_with_truncation=True):
@@ -357,7 +369,8 @@ def momentum_sgd(parameters, lr, momentum, unit_gain,
:func:`momentum_as_time_constant_schedule`): momentum schedule.
For additional information, please refer to the `wiki
<https://github.com/Microsoft/CNTK/wiki/SGD-block#converting-learning-rate-and-momentum-parameters-from-other-toolkits>`_.
unit_gain: when ``True``, momentum is interpreted as a unit-gain filter.
unit_gain: when ``True``, momentum is interpreted as a unit-gain filter. Defaults
to the value returned by :func:`default_unit_gain_value`.
l1_regularization_weight (float, optional): the L1 regularization weight per sample,
defaults to 0.0
l2_regularization_weight (float, optional): the L2 regularization weight per sample,
@@ -389,7 +402,7 @@ def momentum_sgd(parameters, lr, momentum, unit_gain,
additional_options)
@typemap
def nesterov(parameters, lr, momentum, unit_gain,
def nesterov(parameters, lr, momentum, unit_gain=default_unit_gain_value(),
l1_regularization_weight=0.0, l2_regularization_weight=0.0,
gaussian_noise_injection_std_dev=0.0, gradient_clipping_threshold_per_sample=np.inf,
gradient_clipping_with_truncation=True):
@@ -406,7 +419,8 @@ def nesterov(parameters, lr, momentum, unit_gain,
:func:`momentum_as_time_constant_schedule`): momentum schedule.
For additional information, please refer to the `wiki
<https://github.com/Microsoft/CNTK/wiki/SGD-block#converting-learning-rate-and-momentum-parameters-from-other-toolkits>`_.
unit_gain: when ``True``, momentum is interpreted as a unit-gain filter.
unit_gain: when ``True``, momentum is interpreted as a unit-gain filter. Defaults
to the value returned by :func:`default_unit_gain_value`.
l1_regularization_weight (float, optional): the L1 regularization weight per sample,
defaults to 0.0
l2_regularization_weight (float, optional): the L2 regularization weight per sample,
@@ -496,7 +510,7 @@ def adagrad(parameters, lr, need_ave_multiplier=True,
# TODO: unCamelCase and integrate upcoming CR
@typemap
def adam_sgd(parameters, lr, momentum, unit_gain,
def adam_sgd(parameters, lr, momentum, unit_gain=default_unit_gain_value(),
variance_momentum = momentum_as_time_constant_schedule(720000),
low_memory=True,
l1_regularization_weight=0.0, l2_regularization_weight=0.0,
@@ -514,7 +528,8 @@ def adam_sgd(parameters, lr, momentum, unit_gain,
:func:`momentum_as_time_constant_schedule`): momentum schedule.
For additional information, please refer to the `wiki
<https://github.com/Microsoft/CNTK/wiki/SGD-block#converting-learning-rate-and-momentum-parameters-from-other-toolkits>`_.
unit_gain: when ``True``, momentum is interpreted as a unit-gain filter.
unit_gain: when ``True``, momentum is interpreted as a unit-gain filter. Defaults
to the value returned by :func:`default_unit_gain_value`.
variance_momentum (output of :func:`momentum_schedule` or
:func:`momentum_as_time_constant_schedule`): variance momentum schedule. Defaults
to ``momentum_as_time_constant_schedule(720000)``.
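
As the updated docstrings above note, momentum_sgd, nesterov, and adam_sgd now treat unit_gain identically: when omitted it takes the value returned by default_unit_gain_value(). A minimal sketch under the same assumptions as the example after the commit header (hypothetical model and schedules, assumed import paths):

from cntk.ops import input_variable, parameter, times
from cntk.learner import (momentum_sgd, nesterov, adam_sgd,
                          learning_rate_schedule, UnitType,
                          momentum_as_time_constant_schedule)

z = times(input_variable(4), parameter((4, 2)))     # hypothetical tiny model
lr_schedule = learning_rate_schedule(0.01, UnitType.sample)
mm_schedule = momentum_as_time_constant_schedule(1100)

# unit_gain may now be omitted from all three momentum-based factories.
m = momentum_sgd(z.parameters, lr_schedule, mm_schedule)
n = nesterov(z.parameters, lr_schedule, mm_schedule)
a = adam_sgd(z.parameters, lr_schedule, mm_schedule)

# Passing the flag explicitly still works and overrides the global default.
n_explicit = nesterov(z.parameters, lr_schedule, mm_schedule, unit_gain=False)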

View file

@@ -71,18 +71,35 @@ def test_learner_init():
param = learner_parameter[0]
assert isinstance(param, Parameter)
unit_gain_value = default_unit_gain_value()
assert unit_gain_value
momentum_time_constant = momentum_as_time_constant_schedule(1100)
lr_per_sample = learning_rate_schedule(0.1, UnitType.sample)
momentum_sgd(res.parameters, lr_per_sample, momentum_time_constant, True)
momentum_sgd(res.parameters, lr_per_sample, momentum_time_constant)
momentum_sgd(res.parameters, lr_per_sample, momentum_time_constant, unit_gain_value)
momentum_sgd(res.parameters, lr_per_sample, momentum_time_constant, unit_gain=unit_gain_value)
set_default_unit_gain_value(False)
unit_gain_value = default_unit_gain_value()
assert not unit_gain_value
lr_per_sample = learning_rate_schedule([0.1, 0.2], UnitType.sample)
nesterov(res.parameters, lr=lr_per_sample, momentum=momentum_time_constant, unit_gain=False)
nesterov(res.parameters, lr=lr_per_sample, momentum=momentum_time_constant)
nesterov(res.parameters, lr_per_sample, momentum_time_constant, unit_gain_value)
nesterov(res.parameters, lr=lr_per_sample, momentum=momentum_time_constant, unit_gain=unit_gain_value)
lr_per_sample = learning_rate_schedule([0.1]*3 +[0.2]*2 +[0.3], UnitType.sample)
adagrad(res.parameters, lr=lr_per_sample, need_ave_multiplier=True)
set_default_unit_gain_value(True)
unit_gain_value = default_unit_gain_value()
assert unit_gain_value
lr_per_sample = learning_rate_schedule([(3,0.1), (2, 0.2), (1, 0.3)], UnitType.sample)
adam_sgd(res.parameters, lr=lr_per_sample, momentum=momentum_time_constant, unit_gain=True)
adam_sgd(res.parameters, lr=lr_per_sample, momentum=momentum_time_constant)
adam_sgd(res.parameters, lr_per_sample, momentum_time_constant, unit_gain_value)
adam_sgd(res.parameters, lr=lr_per_sample, momentum=momentum_time_constant, unit_gain=unit_gain_value)
gamma, inc, dec, max, min = [0.1]*5
lr_per_sample = learning_rate_schedule([0.1, 0.2], UnitType.sample, 100)