Adding documentation for the training session

This commit is contained in:
Eldar Akchurin 2017-02-09 11:16:51 +01:00
Родитель b7d4945a8e
Коммит 1241fab262
3 изменённых файлов: 150 добавлений и 87 удалений

Просмотреть файл

@ -9,29 +9,90 @@ from .device import use_default_device
from .utils import sanitize_var_map, sanitize_function, typemap, value_to_seq
from .io import _py_dict_to_cntk_dict
__doc__= '''\
A training session encapsulates a typical training loop and binds together the minibatch source, the :doc:`trainer <cntk.trainer>` and checkpointing.
__doc__ = '''\
A training session encapsulates a typical training loop and binds together a minibatch source that is used for training, a :doc:`trainer <cntk.trainer>` and an optional cross validation minibatch source. A training session takes care of consistent checkpointing and progress printing with specified frequencies.
'''
class TrainingSession(cntk_py.TrainingSession):
'''
A training session is an abstraction that encapsulates a typical training loop given
a minibatch source and a :doc:`trainer <cntk.trainer>` and takes care of checkpointing.
The instance of the class should be created by using :func:`~cntk.training_session.training_session` function.
A training session trains a model using the specified ``trainer`` and the ``training_minibatch_source``
where the minibatch size is defined by ``mb_size_schedule``. The mapping between the input variables and the
corresponding input streams should be specified using ``model_inputs_to_mb_source_mapping``.
The size of the training set can be controlled either during creation of the training minibatch
source or using ``max_training_samples`` parameter.
Checkpointing is done both for the trainer and the training minibatch source.
Progress printing happens each ``progress_frequency`` samples using the provided ``progress_printer``.
Args:
training_minibatch_source (:class:`~cntk.io.MinibatchSource`): minibatch source used for training
trainer (:class:`~cntk.trainer.Trainer`): trainer
mb_size_schedule (:class:`~cntk.cntk_py.minibatch_size_schedule`): minibatch schedule for training
progress_printer (:class:`~cntk.utils.progress_print.ProgressPrinter`): progress printer
model_inputs_to_mb_source_mapping (dict): mapping between input variables and input streams
checkpoint_frequency (int): checkpoint frequency in samples. If 0, no checkpointing takes place.
If ``sys.maxsize``, a single checkpoint is taken at the end of the training.
checkpoint_filename (str): checkpoint file name.
save_all_checkpoints (bool): saves all checkpoints, using ``checkpoint_filename`` as prefix and checkpoint index as a suffix.
restore (bool): flag, indicating whether to restore from available checkpoint before the start of the training
progress_frequency (int): frequency in samples for aggregated progress printing
cv_source (:class:`~cntk.io.MinibatchSource`): minibatch source used for cross validation
cv_frequency (int): frequency in samples for cross validation
If ``sys.maxsize``, a single cross validation is performed at the end of training.
cv_mb_size_schedule (:class:`~cntk.cntk_py.minibatch_size_schedule`): minibatch schedule for cross validation
max_training_samples (int): maximum number of samples used for training
'''
def __init__(self, training_minibatch_source, trainer, mb_size_schedule,
progress_printer, model_inputs_to_mb_source_mapping,
checkpoint_frequency, checkpoint_filename, save_all_checkpoints,
progress_printer, model_inputs_to_mb_source_mapping,
checkpoint_frequency, checkpoint_filename, save_all_checkpoints,
restore, progress_frequency, cv_source, cv_frequency, cv_mb_size_schedule, max_training_samples):
self.progress_printer = progress_printer
self.trainer=trainer
self.trainer = trainer
super(TrainingSession, self).__init__ (
training_minibatch_source,
trainer,
model_inputs_to_mb_source_mapping,
mb_size_schedule,
checkpoint_frequency,
if not isinstance(mb_size_schedule, cntk_py.minibatch_size_schedule):
raise ValueError('mb_size_schedule type (%s) not supported. '
'mb_size_schedule must be a schedule '
'(output of minibatch_size_schedule() function)'
% type(mb_size_schedule))
if checkpoint_filename is None:
if checkpoint_frequency is not None and checkpoint_frequency != 0:
raise ValueError(
"Checkpoint frequency cannot be specified without checkpoint_filename")
checkpoint_frequency = 0
checkpoint_filename = ""
if progress_frequency is None:
progress_frequency = sys.maxsize
if cv_source is None:
if cv_frequency is not None and cv_frequency != 0:
raise ValueError(
"Cross validation frequency cannot be specified without cross validation minibatch source")
cv_frequency = 0
if cv_frequency is None:
cv_frequency = sys.maxsize
if max_training_samples is None:
max_training_samples = sys.maxsize
if checkpoint_frequency is None:
checkpoint_frequency = sys.maxsize
if cv_mb_size_schedule is None:
cv_mb_size_schedule = minibatch_size_schedule(1)
super(TrainingSession, self).__init__(
training_minibatch_source,
trainer,
model_inputs_to_mb_source_mapping,
mb_size_schedule,
checkpoint_frequency,
checkpoint_filename,
cv_source,
cv_mb_size_schedule,
@ -44,7 +105,11 @@ class TrainingSession(cntk_py.TrainingSession):
@typemap
def train(self, device=None):
'''
Performs training.
Perform training on a specified device.
Args:
device (:class:`~cntk.device.DeviceDescriptor`): the device descriptor containing
the type and id of the device where training takes place.
'''
if not device:
@ -53,18 +118,39 @@ class TrainingSession(cntk_py.TrainingSession):
super(TrainingSession, self).train(device)
def on_minibatch_end(self):
'''
Callback that gets executed at the end of each minibatch.
If a progress printer is set and the trainer has already seen at least one
sample, the progress printer is updated from the trainer with ``with_metric=True``.
Otherwise the callback does nothing.
'''
if self.progress_printer and self.trainer.total_number_of_samples_seen != 0:
# NOTE(review): the same call appears below in a single-line and a wrapped form;
# this looks like the removed and added sides of a diff hunk - confirm only one is intended.
self.progress_printer.update_with_trainer(self.trainer, with_metric=True)
self.progress_printer.update_with_trainer(
self.trainer, with_metric=True)
def on_progress(self, index):
'''
Callback that gets executed every ``progress_frequency`` samples.
If a progress printer is set, its ``epoch_summary`` is called with
``with_metric=True``; otherwise the callback does nothing.
Args:
index (int): index of the current callback. Not used by this implementation.
'''
if self.progress_printer:
self.progress_printer.epoch_summary(with_metric=True)
def on_cross_validation_end(self, index, average_error, num_samples, num_minibatches):
'''
Callback that gets executed at the end of cross validation.
If a progress printer is set, a one-line summary is logged through it;
otherwise the callback does nothing.
Args:
index (int): index of the current callback; reported as ``index + 1`` in the log message.
average_error (float): average error for the cross validation; reported as a percentage
(``average_error * 100``) with two decimals.
num_samples (int): number of samples in cross validation
num_minibatches (int): number of minibatches in cross validation
'''
if self.progress_printer:
# NOTE(review): ``msg`` is assigned twice with the same format string and arguments
# (single-line and wrapped form); this looks like the removed and added sides of a
# diff hunk - confirm only one is intended.
msg = "Cross Validation [{}]: Minibatch[1-{}]: errs = {:0.2f}% * {}".format(index + 1, num_minibatches, average_error * 100, num_samples)
msg = "Cross Validation [{}]: Minibatch[1-{}]: errs = {:0.2f}% * {}".format(
index + 1, num_minibatches, average_error * 100, num_samples)
self.progress_printer.log(msg)
@typemap
def minibatch_size_schedule(schedule, epoch_size=1):
'''
@ -99,95 +185,63 @@ def minibatch_size_schedule(schedule, epoch_size=1):
if isinstance(schedule, int):
if epoch_size != 1:
raise ValueError('when providing the schedule as a number,'
' epoch_size is ignored')
' epoch_size is ignored')
return cntk_py.minibatch_size_schedule(schedule)
if isinstance(schedule, list):
return cntk_py.minibatch_size_schedule(schedule, epoch_size)
raise ValueError('schedule must be either a float or a list, not %s'%type(schedule))
raise ValueError(
'schedule must be either a float or a list, not %s' % type(schedule))
@typemap
def training_session(training_minibatch_source,
trainer, mb_size_schedule,
progress_printer = None,
model_inputs_to_mb_source_mapping = {},
checkpoint_filename = None,
checkpoint_frequency = None,
save_all_checkpoints = False,
restore = True,
progress_frequency = None,
cv_source = None,
cv_mb_size_schedule = None,
cv_frequency = None,
max_training_samples = None):
progress_printer=None,
model_inputs_to_mb_source_mapping={},
checkpoint_filename=None,
checkpoint_frequency=None,
save_all_checkpoints=False,
restore=True,
progress_frequency=None,
cv_source=None,
cv_mb_size_schedule=None,
cv_frequency=None,
max_training_samples=None):
'''
Creates a basic training session.
A factory function to create a training session object.
Args:
training_minibatch_source: a minibatch source that will be used for training.
trainer: a Trainer.
mb_size_schedule: a minibatch size schedule for training. Created using :func:`minibatch_size_schedule`
progress_printer: a progress printer instance
model_inputs_to_mb_source_mapping: mapping between the input node names of the model and the stream
names provided from the minibatch source. By default all streams are taken with their respective names.
checkpoint_filename: a file name of the checkpoint file, if None, the checkpointing is disabled.
checkpoint_frequency: an approximate number of global samples processed across the workers
after which the checkpoint is taken. Should be positive number if the checkpoint file is specified.
save_all_checkpoints: flag, indicating whether to store all checkpoints, by default only the last checkpoint is preserved
restore: flag, indicating whether perform restore of the training session from the checkpoint before the start of the training
progress_frequency: an approximate number of global samples processed across the workers
after which the summary of metrics is reported using the progress_printer
cv_source: a minibatch source that will be used for cross validation.
cv_mb_size_schedule: a minibatch size schedule for cross validation. Created using :func:`minibatch_size_schedule`
cv_frequency: an approximate number of global samples processed across the workers
after which the cross validation takes place
max_training_samples: max number of samples after which the training should be stopped
Args:
training_minibatch_source (:class:`~cntk.io.MinibatchSource`): minibatch source used for training
trainer (:class:`~cntk.trainer.Trainer`): trainer
mb_size_schedule (:class:`~cntk.cntk_py.minibatch_size_schedule`): minibatch schedule for training
progress_printer (:class:`~cntk.utils.progress_print.ProgressPrinter`): progress printer
model_inputs_to_mb_source_mapping (dict): mapping between input variables and input streams
checkpoint_filename (str): checkpoint file name.
checkpoint_frequency (int): checkpoint frequency in samples. If 0, no checkpointing takes place.
If ``sys.maxsize``, a single checkpoint is taken at the end of the training.
save_all_checkpoints (bool): saves all checkpoints, using ``checkpoint_filename`` as prefix and checkpoint index as a suffix.
restore (bool): flag, indicating whether to restore from available checkpoint before the start of the training
progress_frequency (int): frequency in samples for aggregated progress printing
cv_source (:class:`~cntk.io.MinibatchSource`): minibatch source used for cross validation
cv_frequency (int): frequency in samples for cross validation
If ``sys.maxsize``, a single cross validation is performed at the end of training.
cv_mb_size_schedule (:class:`~cntk.cntk_py.minibatch_size_schedule`): minibatch schedule for cross validation
max_training_samples (int): maximum number of samples used for training
Returns:
Instance of a :class:`TrainingSession`
Instance of :class:`~TrainingSession`
'''
if not isinstance(mb_size_schedule, cntk_py.minibatch_size_schedule):
raise ValueError('mb_size_schedule type (%s) not supported. '
'mb_size_schedule must be a schedule '
'(output of minibatch_size_schedule() function)'
% type(mb_size_schedule))
if checkpoint_filename is None:
if checkpoint_frequency is not None and checkpoint_frequency != 0:
raise ValueError("Checkpoint frequency cannot be specified without checkpoint_filename")
checkpoint_frequency = 0
checkpoint_filename=""
if progress_frequency is None:
progress_frequency = sys.maxsize
if cv_source is None:
if cv_frequency is not None and cv_frequency != 0:
raise ValueError("Cross validation frequency cannot be specified without cross validation minibatch source")
cv_frequency = 0
if cv_frequency is None:
cv_frequency = sys.maxsize
if max_training_samples is None:
max_training_samples = sys.maxsize
if checkpoint_frequency is None:
checkpoint_frequency = sys.maxsize
if cv_mb_size_schedule is None:
cv_mb_size_schedule = minibatch_size_schedule(1)
return TrainingSession(training_minibatch_source, trainer,
mb_size_schedule, progress_printer,
model_inputs_to_mb_source_mapping,
return TrainingSession(training_minibatch_source, trainer,
mb_size_schedule, progress_printer,
model_inputs_to_mb_source_mapping,
checkpoint_frequency,
checkpoint_filename,
save_all_checkpoints=save_all_checkpoints,
restore=restore,
progress_frequency=progress_frequency,
cv_source=cv_source,
cv_mb_size_schedule=cv_mb_size_schedule,
cv_frequency=cv_frequency,
cv_mb_size_schedule=cv_mb_size_schedule,
max_training_samples=max_training_samples)

Просмотреть файл

@ -6,7 +6,9 @@ Python API Reference
Graph components <graph>
IO <cntk.io>
Trainer & learners <cntk.trainer>
Learner <cntk.learner>
Trainer <cntk.trainer>
Training session <cntk.training_session>
Operators <cntk.ops>
Utils <cntk.utils>
Module reference <modules>

Просмотреть файл

@ -0,0 +1,7 @@
cntk.training_session package
=============================
.. automodule:: cntk.training_session
:members:
:undoc-members:
:show-inheritance: