Add a default (true) for the unit-gain flag value
Parent: cc7edb6d11
Commit: 9b6c6bde34

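The change follows one pattern throughout the example scripts below: the explicit `unit_gain = True` argument is dropped because the flag now defaults to a global value that starts out as `True`. A before/after sketch of that pattern (identifiers as they appear in the diff; illustrative only, not an additional change):

# Before: every call site had to spell out the unit-gain flag.
learner = cntk.learner.momentum_sgd(z.parameters, lr_schedule, mm_schedule,
                                    unit_gain = True,
                                    l2_regularization_weight = l2_reg_weight)

# After: unit_gain falls back to default_unit_gain_value(), which is True
# unless it has been changed with set_default_unit_gain_value().
learner = cntk.learner.momentum_sgd(z.parameters, lr_schedule, mm_schedule,
                                    l2_regularization_weight = l2_reg_weight)
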
@@ -73,8 +73,7 @@ def convnet_cifar10(debug_output=False):
     l2_reg_weight = 0.002

     # Instantiate the trainer object to drive the model training
-    learner = cntk.learner.momentum_sgd(z.parameters, lr_schedule, mm_schedule,
-                                        unit_gain = True,
+    learner = cntk.learner.momentum_sgd(z.parameters, lr_schedule, mm_schedule,
                                         l2_regularization_weight = l2_reg_weight)
     trainer = cntk.Trainer(z, ce, pe, learner)

@@ -85,7 +85,6 @@ def convnet_cifar10_dataaug(reader_train, reader_test, max_epochs = 80):

     # trainer object
     learner = cntk.learner.momentum_sgd(z.parameters, lr_schedule, mm_schedule,
-                                        unit_gain = True,
                                         l2_regularization_weight = l2_reg_weight)
     trainer = cntk.Trainer(z, ce, pe, learner)

@@ -103,7 +103,7 @@ def create_trainer(network, epoch_size, num_quantization_bits):

     # Create learner
     learner = data_parallel_distributed_learner(
-        cntk.learner.momentum_sgd(network['output'].parameters, lr_schedule, mm_schedule, unit_gain=True, l2_regularization_weight=l2_reg_weight),
+        cntk.learner.momentum_sgd(network['output'].parameters, lr_schedule, mm_schedule, l2_regularization_weight=l2_reg_weight),
         num_quantization_bits=num_quantization_bits,
         distributed_after=0)

@@ -64,7 +64,7 @@ def convnet_mnist(debug_output=False):
     mm_schedule = cntk.learner.momentum_as_time_constant_schedule(mm_time_constant, epoch_size)

     # Instantiate the trainer object to drive the model training
-    learner = cntk.learner.momentum_sgd(z.parameters, lr_schedule, mm_schedule, unit_gain=True)
+    learner = cntk.learner.momentum_sgd(z.parameters, lr_schedule, mm_schedule)
     trainer = cntk.Trainer(z, ce, pe, learner)

     # define mapping from reader streams to network inputs
@@ -88,7 +88,6 @@ def train_and_evaluate(reader_train, reader_test, network_name, max_epochs):

     # trainer object
     learner = momentum_sgd(z.parameters, lr_schedule, mm_schedule,
-                           unit_gain = True,
                            l2_regularization_weight = l2_reg_weight)
     trainer = Trainer(z, ce, pe, learner)

@@ -83,7 +83,7 @@ def create_trainer(network, minibatch_size, epoch_size, num_quantization_bits):

     # learner object
     local_learner = momentum_sgd(network['output'].parameters, lr_schedule, mm_schedule,
-                                 unit_gain = True, l2_regularization_weight = l2_reg_weight)
+                                 l2_regularization_weight = l2_reg_weight)

     learner = data_parallel_distributed_learner(learner=local_learner,
                                                 num_quantization_bits=num_quantization_bits,
@@ -83,7 +83,6 @@ def train(reader, model, max_epochs):
     lr_per_sample = learning_rate_schedule(lr_schedule, UnitType.sample, epoch_size)
     learner = adam_sgd(z.parameters,
                        lr=lr_per_sample, momentum=momentum_time_constant,
-                       unit_gain=True,
                        low_memory=True,
                        gradient_clipping_threshold_per_sample=15, gradient_clipping_with_truncation=True)

@@ -158,7 +158,6 @@ def sequence_to_sequence_translator(debug_output=False, run_test=False):
     gradient_clipping_with_truncation = True
     learner = momentum_sgd(z.parameters,
                            lr_per_minibatch, momentum_time_constant,
-                           unit_gain=True,
                            gradient_clipping_threshold_per_sample=clipping_threshold_per_sample,
                            gradient_clipping_with_truncation=gradient_clipping_with_truncation)
     trainer = Trainer(z, ce, errs, learner)
@@ -163,8 +163,7 @@ def train_lm(training_file):
     momentum_time_constant = momentum_as_time_constant_schedule(1100)
     clipping_threshold_per_sample = 5.0
     gradient_clipping_with_truncation = True
-    learner = momentum_sgd(z.parameters, lr_per_sample, momentum_time_constant,
-                           unit_gain=True,
+    learner = momentum_sgd(z.parameters, lr_per_sample, momentum_time_constant,
                            gradient_clipping_threshold_per_sample=clipping_threshold_per_sample,
                            gradient_clipping_with_truncation=gradient_clipping_with_truncation)
     trainer = Trainer(z, ce, errs, learner)
@@ -3435,6 +3435,16 @@ namespace CNTK
         bool gradientClippingWithTruncation = true;
     };

+    ///
+    /// Returns true if by default momentum is applied in the unit-gain fashion.
+    ///
+    CNTK_API bool DefaultUnitGainValue();
+
+    ///
+    /// Sets globally default unit-gain flag value.
+    ///
+    CNTK_API void SetDefaultUnitGainValue(bool value);
+
     ///
     /// Abstraction for learning a subset of parameters of a learnable Function using first order gradient values
     /// For e.g momentum, AdaGrad, RMSProp etc. are different types of learners with their own algorithms for
@@ -3545,7 +3555,7 @@ namespace CNTK
     CNTK_API LearnerPtr MomentumSGDLearner(const std::vector<Parameter>& parameters,
                                            const LearningRateSchedule& learningRateSchedule,
                                            const MomentumSchedule& momentumSchedule,
-                                           bool unitGain,
+                                           bool unitGain = DefaultUnitGainValue(),
                                            AdditionalLearningOptions additionalOptions = AdditionalLearningOptions());

     ///
@@ -3554,7 +3564,7 @@ namespace CNTK
     CNTK_API LearnerPtr NesterovLearner(const std::vector<Parameter>& parameters,
                                         const LearningRateSchedule& learningRateSchedule,
                                         const MomentumSchedule& momentumSchedule,
-                                        bool unitGain,
+                                        bool unitGain = DefaultUnitGainValue(),
                                         AdditionalLearningOptions additionalOptions = AdditionalLearningOptions());

     static MomentumSchedule DefaultVarianceMomentum = MomentumAsTimeConstantSchedule(2 * 3600 * 100);
@@ -3565,7 +3575,7 @@ namespace CNTK
     CNTK_API LearnerPtr AdamLearner(const std::vector<Parameter>& parameters,
                                     const LearningRateSchedule& learningRateSchedule,
                                     const MomentumSchedule& momentumSchedule,
-                                    bool unitGain,
+                                    bool unitGain = DefaultUnitGainValue(),
                                     const MomentumSchedule& varianceMomentumSchedule = DefaultVarianceMomentum,
                                     bool lowMemory = true,
                                     AdditionalLearningOptions additionalOptions = AdditionalLearningOptions());
@@ -528,4 +528,16 @@ namespace CNTK
     {
         return Microsoft::MSR::CNTK::CPUMatrix<float>::GetMaxNumThreads();
     }
+
+    static std::atomic<bool> s_defaultUnitGainValue(true);
+
+    bool DefaultUnitGainValue()
+    {
+        return s_defaultUnitGainValue;
+    }
+
+    void SetDefaultUnitGainValue(bool value)
+    {
+        s_defaultUnitGainValue.store(value);
+    }
 }
@@ -228,11 +228,23 @@ void TestTrainingParametersSchedule()
     assert(schedule16[99999] == exp(-1.0 / 3.0));
 }

+void TestDefaultUnitGainGetterAndSetter()
+{
+    assert(DefaultUnitGainValue());
+
+    SetDefaultUnitGainValue(false);
+    assert(!DefaultUnitGainValue());
+
+    SetDefaultUnitGainValue(true);
+    assert(DefaultUnitGainValue());
+}
+
 void LearnerTests()
 {
     fprintf(stderr, "\nLearnerTests..\n");

+    TestDefaultUnitGainGetterAndSetter();
+
     TestTrainingParametersSchedule();

     vector<DeviceDescriptor> devices{DeviceDescriptor::CPUDevice()};
@@ -271,7 +271,7 @@
 " # Feel free to try other optimizers from \n",
 " # https://www.cntk.ai/pythondocs/cntk.learner.html#module-cntk.learner\n",
 " learner = adam_sgd(model.parameters,\n",
-" lr=lr_schedule, momentum=momentum_as_time_constant, unit_gain=True) \n",
+" lr=lr_schedule, momentum=momentum_as_time_constant) \n",
 " \n",
 " # Instantiate the trainer\n",
 " trainer = Trainer(model, loss, label_error, learner)\n",
@@ -324,7 +324,6 @@
 " # trainer object\n",
 " learner = momentum_sgd(z.parameters, \n",
 " lr = lr_per_minibatch, momentum = momentum_time_constant, \n",
-" unit_gain = True, \n",
 " l2_regularization_weight=l2_reg_weight)\n",
 " trainer = Trainer(z, ce, pe, [learner])\n",
 "\n",
@@ -392,7 +392,6 @@
 " # https://www.cntk.ai/pythondocs/cntk.learner.html#module-cntk.learner\n",
 " learner = adam_sgd(criterion.parameters,\n",
 " lr=lr_schedule, momentum=momentum_as_time_constant,\n",
-" unit_gain = True, \n",
 " low_memory=True,\n",
 " gradient_clipping_threshold_per_sample=15, gradient_clipping_with_truncation=True)\n",
 "\n",
@@ -498,7 +497,7 @@
 " lr_schedule = learning_rate_schedule(1, UnitType.minibatch)\n",
 " momentum_as_time_constant = momentum_as_time_constant_schedule(0)\n",
 " dummy_learner = adam_sgd(criterion.parameters, \n",
-" lr=lr_schedule, momentum=momentum_as_time_constant, unit_gain=True, low_memory=True)\n",
+" lr=lr_schedule, momentum=momentum_as_time_constant, low_memory=True)\n",
 " evaluator = Trainer(model, criterion.outputs[0], criterion.outputs[1], dummy_learner)\n",
 " progress_printer = ProgressPrinter(tag='Evaluation')\n",
 "\n",
@@ -753,7 +753,6 @@
 "gradient_clipping_with_truncation = True\n",
 "learner = momentum_sgd(model.parameters,\n",
 " lr_per_sample, momentum_time_constant,\n",
-" unit_gain = True, \n",
 " gradient_clipping_threshold_per_sample=clipping_threshold_per_sample,\n",
 " gradient_clipping_with_truncation=gradient_clipping_with_truncation)\n",
 "trainer = Trainer(model, ce, errs, learner)"
@@ -920,7 +919,6 @@
 " gradient_clipping_with_truncation = True\n",
 " learner = momentum_sgd(model.parameters,\n",
 " lr_per_sample, momentum_time_constant,\n",
-" unit_gain = True, \n",
 " gradient_clipping_threshold_per_sample=clipping_threshold_per_sample,\n",
 " gradient_clipping_with_truncation=gradient_clipping_with_truncation)\n",
 " trainer = Trainer(model, ce, errs, learner)\n",
@@ -309,18 +309,18 @@
 %ignore CNTK::MomentumSGDLearner(const std::vector<Parameter>& parameters,
                                  const LearningRateSchedule& learningRateSchedule,
                                  const MomentumSchedule& momentumSchedule,
-                                 bool unitGain,
+                                 bool unitGain = DefaultUnitGainValue(),
                                  AdditionalLearningOptions additionalOptions = AdditionalLearningOptions());
 %ignore CNTK::NesterovLearner(const std::vector<Parameter>& parameters,
                               const LearningRateSchedule& learningRateSchedule,
                               const MomentumSchedule& momentumSchedule,
-                              bool unitGain,
+                              bool unitGain = DefaultUnitGainValue(),
                               AdditionalLearningOptions additionalOptions = AdditionalLearningOptions());
 %ignore CNTK::DefaultVarianceMomentum;
 %ignore CNTK::AdamLearner(const std::vector<Parameter>& parameters,
                           const LearningRateSchedule& learningRateSchedule,
                           const MomentumSchedule& momentumSchedule,
-                          bool unitGain,
+                          bool unitGain = DefaultUnitGainValue(),
                           const MomentumSchedule& varianceMomentumSchedule = DefaultVarianceMomentum,
                           bool lowMemory = true,
                           AdditionalLearningOptions additionalOptions = AdditionalLearningOptions());
@@ -51,6 +51,18 @@ the following learning algorithms:
 +------------------------+
 '''

+def default_unit_gain_value():
+    '''
+    Returns true if by default momentum is applied in the unit-gain fashion.
+    '''
+    return cntk_py.default_unit_gain_value()
+
+def set_default_unit_gain_value(value):
+    '''
+    Sets globally default unit-gain flag value.
+    '''
+    cntk_py.set_default_unit_gain_value(value)
+
 # an internal method to verify that the learning rate schedule
 # has a proper (per-sample or per-MB schedule) type and raise
 # an exception otherwise
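A minimal usage sketch of the two new module-level helpers (not part of the diff; `z` stands for an existing CNTK model, the schedule values are placeholders, and the imports assume the beta-era `cntk.learner` module layout shown above):

from cntk.learner import (momentum_sgd, learning_rate_schedule, UnitType,
                          momentum_as_time_constant_schedule,
                          default_unit_gain_value, set_default_unit_gain_value)

lr_per_sample = learning_rate_schedule(0.01, UnitType.sample)
mm_schedule = momentum_as_time_constant_schedule(1100)

# The flag ships as True, so unit_gain no longer has to be spelled out.
assert default_unit_gain_value()
learner = momentum_sgd(z.parameters, lr_per_sample, mm_schedule)

# The global value can be changed, read back, and passed explicitly,
# mirroring what the updated unit tests at the end of this commit do.
set_default_unit_gain_value(False)
learner = momentum_sgd(z.parameters, lr_per_sample, mm_schedule,
                       unit_gain=default_unit_gain_value())
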
@@ -342,7 +354,7 @@ def sgd(parameters, lr,
     return cntk_py.sgd_learner(parameters, lr, additional_options)

 @typemap
-def momentum_sgd(parameters, lr, momentum, unit_gain,
+def momentum_sgd(parameters, lr, momentum, unit_gain=default_unit_gain_value(),
                  l1_regularization_weight=0.0, l2_regularization_weight=0.0,
                  gaussian_noise_injection_std_dev=0.0, gradient_clipping_threshold_per_sample=np.inf,
                  gradient_clipping_with_truncation=True):
@@ -357,7 +369,8 @@ def momentum_sgd(parameters, lr, momentum, unit_gain,
         :func:`momentum_as_time_constant_schedule`): momentum schedule.
         For additional information, please refer to the `wiki
         <https://github.com/Microsoft/CNTK/wiki/SGD-block#converting-learning-rate-and-momentum-parameters-from-other-toolkits>`_.
-        unit_gain: when ``True``, momentum is interpreted as a unit-gain filter.
+        unit_gain: when ``True``, momentum is interpreted as a unit-gain filter. Defaults
+            to the value returned by :func:`default_unit_gain_value`.
         l1_regularization_weight (float, optional): the L1 regularization weight per sample,
            defaults to 0.0
         l2_regularization_weight (float, optional): the L2 regularization weight per sample,
@@ -389,7 +402,7 @@ def momentum_sgd(parameters, lr, momentum, unit_gain,
                         additional_options)

 @typemap
-def nesterov(parameters, lr, momentum, unit_gain,
+def nesterov(parameters, lr, momentum, unit_gain=default_unit_gain_value(),
              l1_regularization_weight=0.0, l2_regularization_weight=0.0,
              gaussian_noise_injection_std_dev=0.0, gradient_clipping_threshold_per_sample=np.inf,
              gradient_clipping_with_truncation=True):
@@ -406,7 +419,8 @@ def nesterov(parameters, lr, momentum, unit_gain,
         :func:`momentum_as_time_constant_schedule`): momentum schedule.
         For additional information, please refer to the `wiki
         <https://github.com/Microsoft/CNTK/wiki/SGD-block#converting-learning-rate-and-momentum-parameters-from-other-toolkits>`_.
-        unit_gain: when ``True``, momentum is interpreted as a unit-gain filter.
+        unit_gain: when ``True``, momentum is interpreted as a unit-gain filter. Defaults
+            to the value returned by :func:`default_unit_gain_value`.
         l1_regularization_weight (float, optional): the L1 regularization weight per sample,
            defaults to 0.0
         l2_regularization_weight (float, optional): the L2 regularization weight per sample,
@@ -496,7 +510,7 @@ def adagrad(parameters, lr, need_ave_multiplier=True,

 # TODO: unCamelCase and integrate upcoming CR
 @typemap
-def adam_sgd(parameters, lr, momentum, unit_gain,
+def adam_sgd(parameters, lr, momentum, unit_gain=default_unit_gain_value(),
              variance_momentum = momentum_as_time_constant_schedule(720000),
              low_memory=True,
              l1_regularization_weight=0.0, l2_regularization_weight=0.0,
@@ -514,7 +528,8 @@ def adam_sgd(parameters, lr, momentum, unit_gain,
         :func:`momentum_as_time_constant_schedule`): momentum schedule.
         For additional information, please refer to the `wiki
         <https://github.com/Microsoft/CNTK/wiki/SGD-block#converting-learning-rate-and-momentum-parameters-from-other-toolkits>`_.
-        unit_gain: when ``True``, momentum is interpreted as a unit-gain filter.
+        unit_gain: when ``True``, momentum is interpreted as a unit-gain filter. Defaults
+            to the value returned by :func:`default_unit_gain_value`.
         variance_momentum (output of :func:`momentum_schedule` or
            :func:`momentum_as_time_constant_schedule`): variance momentum schedule. Defaults
            to ``momentum_as_time_constant_schedule(720000)``.
@@ -71,18 +71,35 @@ def test_learner_init():
     param = learner_parameter[0]
     assert isinstance(param, Parameter)

+    unit_gain_value = default_unit_gain_value()
+    assert unit_gain_value
+
     momentum_time_constant = momentum_as_time_constant_schedule(1100)
     lr_per_sample = learning_rate_schedule(0.1, UnitType.sample)
-    momentum_sgd(res.parameters, lr_per_sample, momentum_time_constant, True)
+    momentum_sgd(res.parameters, lr_per_sample, momentum_time_constant)
+    momentum_sgd(res.parameters, lr_per_sample, momentum_time_constant, unit_gain_value)
+    momentum_sgd(res.parameters, lr_per_sample, momentum_time_constant, unit_gain=unit_gain_value)

+    set_default_unit_gain_value(False)
+    unit_gain_value = default_unit_gain_value()
+    assert not unit_gain_value
+
     lr_per_sample = learning_rate_schedule([0.1, 0.2], UnitType.sample)
-    nesterov(res.parameters, lr=lr_per_sample, momentum=momentum_time_constant, unit_gain=False)
+    nesterov(res.parameters, lr=lr_per_sample, momentum=momentum_time_constant)
+    nesterov(res.parameters, lr_per_sample, momentum_time_constant, unit_gain_value)
+    nesterov(res.parameters, lr=lr_per_sample, momentum=momentum_time_constant, unit_gain=unit_gain_value)

     lr_per_sample = learning_rate_schedule([0.1]*3 +[0.2]*2 +[0.3], UnitType.sample)
     adagrad(res.parameters, lr=lr_per_sample, need_ave_multiplier=True)

+    set_default_unit_gain_value(True)
+    unit_gain_value = default_unit_gain_value()
+    assert unit_gain_value
+
     lr_per_sample = learning_rate_schedule([(3,0.1), (2, 0.2), (1, 0.3)], UnitType.sample)
-    adam_sgd(res.parameters, lr=lr_per_sample, momentum=momentum_time_constant, unit_gain=True)
+    adam_sgd(res.parameters, lr=lr_per_sample, momentum=momentum_time_constant)
+    adam_sgd(res.parameters, lr_per_sample, momentum_time_constant, unit_gain_value)
+    adam_sgd(res.parameters, lr=lr_per_sample, momentum=momentum_time_constant, unit_gain=unit_gain_value)

     gamma, inc, dec, max, min = [0.1]*5
     lr_per_sample = learning_rate_schedule([0.1, 0.2], UnitType.sample, 100)