Adding documentation for the training session
This commit is contained in:
Родитель
b7d4945a8e
Коммит
1241fab262
|
@ -9,29 +9,90 @@ from .device import use_default_device
|
|||
from .utils import sanitize_var_map, sanitize_function, typemap, value_to_seq
|
||||
from .io import _py_dict_to_cntk_dict
|
||||
|
||||
__doc__= '''\
|
||||
A training session encapsulates a typical training loop and binds together the minibatch source, the :doc:`trainer <cntk.trainer>` and checkpointing.
|
||||
__doc__ = '''\
|
||||
A training session encapsulates a typical training loop and binds together a minibatch source that is used for training, a :doc:`trainer <cntk.trainer>` and an optional cross validation minibatch source. A training session takes care of consistent checkpointing and progress printing with specified frequencies.
|
||||
'''
|
||||
|
||||
|
||||
class TrainingSession(cntk_py.TrainingSession):
|
||||
'''
|
||||
A training session is an abstraction that encapsulates a typical training loop given
|
||||
a minibatch source and a :doc:`trainer <cntk.trainer>` and takes care of checkpointing.
|
||||
Instances of this class should be created using the :func:`~cntk.training_session.training_session` function.
|
||||
|
||||
A training session trains a model using the specified ``trainer`` and the ``training_minibatch_source``
|
||||
where the minibatch size is defined by ``mb_size_schedule``. The mapping between the input variables and the
|
||||
corresponding input streams should be specified using ``model_inputs_to_mb_source_mapping``.
|
||||
The size of the training set can be controlled either during creation of the training minibatch
|
||||
source or using the ``max_training_samples`` parameter.
|
||||
Checkpointing is done both for the trainer and the training minibatch source.
|
||||
Progress printing happens every ``progress_frequency`` samples using the provided ``progress_printer``.
|
||||
|
||||
Args:
|
||||
training_minibatch_source (:class:`~cntk.io.MinibatchSource`): minibatch source used for training
|
||||
trainer (:class:`~cntk.trainer.Trainer`): trainer
|
||||
mb_size_schedule (:class:`~cntk.cntk_py.minibatch_size_schedule`): minibatch schedule for training
|
||||
progress_printer (:class:`~cntk.utils.progress_print.ProgressPrinter`): progress printer
|
||||
model_inputs_to_mb_source_mapping (dict): mapping between input variables and input streams
|
||||
checkpoint_frequency (int): checkpoint frequency in samples. If 0, no checkpointing takes place.
|
||||
If ``sys.maxsize``, a single checkpoint is taken at the end of the training.
|
||||
checkpoint_filename (str): checkpoint file name.
|
||||
save_all_checkpoints (bool): saves all checkpoints, using ``checkpoint_filename`` as prefix and checkpoint index as a suffix.
|
||||
restore (bool): flag, indicating whether to restore from available checkpoint before the start of the training
|
||||
progress_frequency (int): frequency in samples for aggregated progress printing
|
||||
cv_source (:class:`~cntk.io.MinibatchSource`): minibatch source used for cross validation
|
||||
cv_frequency (int): frequency in samples for cross validation
|
||||
If ``sys.maxsize``, a single cross validation is performed at the end of training.
|
||||
cv_mb_size_schedule (:class:`~cntk.cntk_py.minibatch_size_schedule`): minibatch schedule for cross validation
|
||||
max_training_samples (int): maximum number of samples used for training
|
||||
'''
|
||||
|
||||
def __init__(self, training_minibatch_source, trainer, mb_size_schedule,
|
||||
progress_printer, model_inputs_to_mb_source_mapping,
|
||||
checkpoint_frequency, checkpoint_filename, save_all_checkpoints,
|
||||
progress_printer, model_inputs_to_mb_source_mapping,
|
||||
checkpoint_frequency, checkpoint_filename, save_all_checkpoints,
|
||||
restore, progress_frequency, cv_source, cv_frequency, cv_mb_size_schedule, max_training_samples):
|
||||
|
||||
self.progress_printer = progress_printer
|
||||
self.trainer=trainer
|
||||
self.trainer = trainer
|
||||
|
||||
super(TrainingSession, self).__init__ (
|
||||
training_minibatch_source,
|
||||
trainer,
|
||||
model_inputs_to_mb_source_mapping,
|
||||
mb_size_schedule,
|
||||
checkpoint_frequency,
|
||||
if not isinstance(mb_size_schedule, cntk_py.minibatch_size_schedule):
|
||||
raise ValueError('mb_size_schedule type (%s) not supported. '
|
||||
'mb_size_schedule must be a schedule '
|
||||
'(output of minibatch_size_schedule() function)'
|
||||
% type(mb_size_schedule))
|
||||
|
||||
if checkpoint_filename is None:
|
||||
if checkpoint_frequency is not None and checkpoint_frequency != 0:
|
||||
raise ValueError(
|
||||
"Checkpoint frequency cannot be specified without checkpoint_filename")
|
||||
checkpoint_frequency = 0
|
||||
checkpoint_filename = ""
|
||||
|
||||
if progress_frequency is None:
|
||||
progress_frequency = sys.maxsize
|
||||
|
||||
if cv_source is None:
|
||||
if cv_frequency is not None and cv_frequency != 0:
|
||||
raise ValueError(
|
||||
"Cross validation frequency cannot be specified without cross validation minibatch source")
|
||||
cv_frequency = 0
|
||||
|
||||
if cv_frequency is None:
|
||||
cv_frequency = sys.maxsize
|
||||
|
||||
if max_training_samples is None:
|
||||
max_training_samples = sys.maxsize
|
||||
|
||||
if checkpoint_frequency is None:
|
||||
checkpoint_frequency = sys.maxsize
|
||||
|
||||
if cv_mb_size_schedule is None:
|
||||
cv_mb_size_schedule = minibatch_size_schedule(1)
|
||||
|
||||
super(TrainingSession, self).__init__(
|
||||
training_minibatch_source,
|
||||
trainer,
|
||||
model_inputs_to_mb_source_mapping,
|
||||
mb_size_schedule,
|
||||
checkpoint_frequency,
|
||||
checkpoint_filename,
|
||||
cv_source,
|
||||
cv_mb_size_schedule,
|
||||
|
@ -44,7 +105,11 @@ class TrainingSession(cntk_py.TrainingSession):
|
|||
@typemap
|
||||
def train(self, device=None):
|
||||
'''
|
||||
Performs training.
|
||||
Perform training on a specified device.
|
||||
|
||||
Args:
|
||||
device (:class:`~cntk.device.DeviceDescriptor`): the device descriptor containing
|
||||
the type and id of the device where training takes place.
|
||||
'''
|
||||
|
||||
if not device:
|
||||
|
@ -53,18 +118,39 @@ class TrainingSession(cntk_py.TrainingSession):
|
|||
super(TrainingSession, self).train(device)
|
||||
|
||||
def on_minibatch_end(self):
|
||||
'''
|
||||
Callback that gets executed at the end of each minibatch.
|
||||
'''
|
||||
if self.progress_printer and self.trainer.total_number_of_samples_seen != 0:
|
||||
self.progress_printer.update_with_trainer(self.trainer, with_metric=True)
|
||||
self.progress_printer.update_with_trainer(
|
||||
self.trainer, with_metric=True)
|
||||
|
||||
def on_progress(self, index):
|
||||
'''
|
||||
Callback that gets executed with the ``progress_frequency`` frequency in samples.
|
||||
|
||||
Args:
|
||||
index (int): index of the current callback.
|
||||
'''
|
||||
if self.progress_printer:
|
||||
self.progress_printer.epoch_summary(with_metric=True)
|
||||
|
||||
def on_cross_validation_end(self, index, average_error, num_samples, num_minibatches):
|
||||
'''
|
||||
Callback that gets executed at the end of cross validation.
|
||||
|
||||
Args:
|
||||
index (int): index of the current callback.
|
||||
average_error (float): average error for the cross validation
|
||||
num_samples (int): number of samples in cross validation
|
||||
num_minibatches (int): number of minibatches in cross validation
|
||||
'''
|
||||
if self.progress_printer:
|
||||
msg = "Cross Validation [{}]: Minibatch[1-{}]: errs = {:0.2f}% * {}".format(index + 1, num_minibatches, average_error * 100, num_samples)
|
||||
msg = "Cross Validation [{}]: Minibatch[1-{}]: errs = {:0.2f}% * {}".format(
|
||||
index + 1, num_minibatches, average_error * 100, num_samples)
|
||||
self.progress_printer.log(msg)
|
||||
|
||||
|
||||
@typemap
|
||||
def minibatch_size_schedule(schedule, epoch_size=1):
|
||||
'''
|
||||
|
@ -99,95 +185,63 @@ def minibatch_size_schedule(schedule, epoch_size=1):
|
|||
if isinstance(schedule, int):
|
||||
if epoch_size != 1:
|
||||
raise ValueError('when providing the schedule as a number,'
|
||||
' epoch_size is ignored')
|
||||
' epoch_size is ignored')
|
||||
return cntk_py.minibatch_size_schedule(schedule)
|
||||
|
||||
if isinstance(schedule, list):
|
||||
return cntk_py.minibatch_size_schedule(schedule, epoch_size)
|
||||
|
||||
raise ValueError('schedule must be either a float or a list, not %s'%type(schedule))
|
||||
raise ValueError(
|
||||
'schedule must be either a float or a list, not %s' % type(schedule))
|
||||
|
||||
|
||||
@typemap
|
||||
def training_session(training_minibatch_source,
|
||||
trainer, mb_size_schedule,
|
||||
progress_printer = None,
|
||||
model_inputs_to_mb_source_mapping = {},
|
||||
checkpoint_filename = None,
|
||||
checkpoint_frequency = None,
|
||||
save_all_checkpoints = False,
|
||||
restore = True,
|
||||
progress_frequency = None,
|
||||
cv_source = None,
|
||||
cv_mb_size_schedule = None,
|
||||
cv_frequency = None,
|
||||
max_training_samples = None):
|
||||
progress_printer=None,
|
||||
model_inputs_to_mb_source_mapping={},
|
||||
checkpoint_filename=None,
|
||||
checkpoint_frequency=None,
|
||||
save_all_checkpoints=False,
|
||||
restore=True,
|
||||
progress_frequency=None,
|
||||
cv_source=None,
|
||||
cv_mb_size_schedule=None,
|
||||
cv_frequency=None,
|
||||
max_training_samples=None):
|
||||
'''
|
||||
Creates a basic training session.
|
||||
A factory function to create a training session object.
|
||||
|
||||
Args:
|
||||
training_minibatch_source: a minibatch source that will be used for training.
|
||||
trainer: a Trainer.
|
||||
mb_size_schedule: a minibatch size schedule for training. Created using :func:`minibatch_size_schedule`
|
||||
progress_printer: a progress printer instance
|
||||
model_inputs_to_mb_source_mapping: mapping between the input node names of the model and the stream
|
||||
names provided from the minibatch source. By default all streams are taken with their respective names.
|
||||
checkpoint_filename: a file name of the checkpoint file, if None, the checkpointing is disabled.
|
||||
checkpoint_frequency: an approximate number of global samples processed across the workers
|
||||
after which the checkpoint is taken. Should be positive number if the checkpoint file is specified.
|
||||
save_all_checkpoints: flag, indicating whether to store all checkpoints, by default only the last checkpoint is preserved
|
||||
restore: flag, indicating whether perform restore of the training session from the checkpoint before the start of the training
|
||||
progress_frequency: an approximate number of global samples processed across the workers
|
||||
after which the summary of metrics is reported using the progress_printer
|
||||
cv_source: a minibatch source that will be used for cross validation.
|
||||
cv_mb_size_schedule: a minibatch size schedule for cross validation. Created using :func:`minibatch_size_schedule`
|
||||
cv_frequency: an approximate number of global samples processed across the workers
|
||||
after which the cross validation takes place
|
||||
max_training_samples: max number of samples after which the training should be stopped
|
||||
Args:
|
||||
training_minibatch_source (:class:`~cntk.io.MinibatchSource`): minibatch source used for training
|
||||
trainer (:class:`~cntk.trainer.Trainer`): trainer
|
||||
mb_size_schedule (:class:`~cntk.cntk_py.minibatch_size_schedule`): minibatch schedule for training
|
||||
progress_printer (:class:`~cntk.utils.progress_print.ProgressPrinter`): progress printer
|
||||
model_inputs_to_mb_source_mapping (dict): mapping between input variables and input streams
|
||||
checkpoint_filename (str): checkpoint file name.
|
||||
checkpoint_frequency (int): checkpoint frequency in samples. If 0, no checkpointing takes place.
|
||||
If ``sys.maxsize``, a single checkpoint is taken at the end of the training.
|
||||
save_all_checkpoints (bool): saves all checkpoints, using ``checkpoint_filename`` as prefix and checkpoint index as a suffix.
|
||||
restore (bool): flag, indicating whether to restore from available checkpoint before the start of the training
|
||||
progress_frequency (int): frequency in samples for aggregated progress printing
|
||||
cv_source (:class:`~cntk.io.MinibatchSource`): minibatch source used for cross validation
|
||||
cv_frequency (int): frequency in samples for cross validation
|
||||
cv_mb_size_schedule (:class:`~cntk.cntk_py.minibatch_size_schedule`): minibatch schedule for cross validation
|
||||
If ``cv_frequency`` is ``sys.maxsize``, a single cross validation is performed at the end of training.
|
||||
max_training_samples (int): maximum number of samples used for training
|
||||
|
||||
Returns:
|
||||
Instance of a :class:`TrainingSession`
|
||||
Instance of :class:`~TrainingSession`
|
||||
'''
|
||||
if not isinstance(mb_size_schedule, cntk_py.minibatch_size_schedule):
|
||||
raise ValueError('mb_size_schedule type (%s) not supported. '
|
||||
'mb_size_schedule must be a schedule '
|
||||
'(output of minibatch_size_schedule() function)'
|
||||
% type(mb_size_schedule))
|
||||
|
||||
if checkpoint_filename is None:
|
||||
if checkpoint_frequency is not None and checkpoint_frequency != 0:
|
||||
raise ValueError("Checkpoint frequency cannot be specified without checkpoint_filename")
|
||||
checkpoint_frequency = 0
|
||||
checkpoint_filename=""
|
||||
|
||||
if progress_frequency is None:
|
||||
progress_frequency = sys.maxsize
|
||||
|
||||
if cv_source is None:
|
||||
if cv_frequency is not None and cv_frequency != 0:
|
||||
raise ValueError("Cross validation frequency cannot be specified without cross validation minibatch source")
|
||||
cv_frequency = 0
|
||||
|
||||
if cv_frequency is None:
|
||||
cv_frequency = sys.maxsize
|
||||
|
||||
if max_training_samples is None:
|
||||
max_training_samples = sys.maxsize
|
||||
|
||||
if checkpoint_frequency is None:
|
||||
checkpoint_frequency = sys.maxsize
|
||||
|
||||
if cv_mb_size_schedule is None:
|
||||
cv_mb_size_schedule = minibatch_size_schedule(1)
|
||||
|
||||
return TrainingSession(training_minibatch_source, trainer,
|
||||
mb_size_schedule, progress_printer,
|
||||
model_inputs_to_mb_source_mapping,
|
||||
return TrainingSession(training_minibatch_source, trainer,
|
||||
mb_size_schedule, progress_printer,
|
||||
model_inputs_to_mb_source_mapping,
|
||||
checkpoint_frequency,
|
||||
checkpoint_filename,
|
||||
save_all_checkpoints=save_all_checkpoints,
|
||||
restore=restore,
|
||||
progress_frequency=progress_frequency,
|
||||
cv_source=cv_source,
|
||||
cv_mb_size_schedule=cv_mb_size_schedule,
|
||||
cv_frequency=cv_frequency,
|
||||
cv_mb_size_schedule=cv_mb_size_schedule,
|
||||
max_training_samples=max_training_samples)
|
||||
|
|
|
@ -6,7 +6,9 @@ Python API Reference
|
|||
|
||||
Graph components <graph>
|
||||
IO <cntk.io>
|
||||
Trainer & learners <cntk.trainer>
|
||||
Learner <cntk.learner>
|
||||
Trainer <cntk.trainer>
|
||||
Training session <cntk.training_session>
|
||||
Operators <cntk.ops>
|
||||
Utils <cntk.utils>
|
||||
Module reference <modules>
|
||||
|
|
|
@ -0,0 +1,7 @@
|
|||
cntk.training_session package
|
||||
=============================
|
||||
|
||||
.. automodule:: cntk.training_session
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
Загрузка…
Ссылка в новой задаче