implement distributed training using horovod

NanoNabla 2021-02-16 12:37:06 +01:00
Parent 7b2eeb6734
Commit 11edd92775
6 changed files with 397 additions and 35 deletions

View file

@@ -196,6 +196,21 @@ python3 DeepSpeech.py --train_files ./train.csv --dev_files ./dev.csv --test_fil
On a Volta generation V100 GPU, automatic mixed precision speeds up DeepSpeech training and evaluation by ~30%-40%.
Distributed training using Horovod
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
If you have a capable compute cluster, you can distribute training across multiple machines and GPUs using `Horovod <https://github.com/horovod/horovod>`_. A fast network connection between the machines is recommended.
Horovod can use MPI and NVIDIA's NCCL for highly optimized inter-process communication.
It also offers Gloo as an easy-to-set-up communication backend.
For more information about setting up or tuning Horovod, please visit `Horovod's GitHub <https://github.com/horovod/horovod>`_.
To train on 4 machines using 4 GPUs each:

.. code-block:: bash

    horovodrun -np 16 -H server1:4,server2:4,server3:4,server4:4 python3 DeepSpeech.py --train_files [...] --horovod
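
If MPI and NCCL are not available, Horovod's Gloo backend can be selected by passing ``--gloo`` to ``horovodrun``; see the Horovod documentation for the full set of launcher options.
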
Checkpointing
^^^^^^^^^^^^^

View file

@@ -76,6 +76,10 @@ def main():
'tensorflow == 1.15.4'
]
horovod_pypi_dep = [
'horovod'
]
# Due to pip craziness environment variables are the only consistent way to
# get options into this script when doing `pip install`.
tc_decoder_artifacts_root = os.environ.get('DECODER_ARTIFACTS_ROOT', '')
@@ -94,6 +98,12 @@ def main():
else:
install_requires = install_requires + tensorflow_pypi_dep
if os.environ.get('DS_NOHOROVOD', ''):
install_requires = install_requires
else:
install_requires = install_requires + horovod_pypi_dep
setup(
name='deepspeech_training',
version=version,
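
With this change, setting the DS_NOHOROVOD environment variable to any non-empty value before installing the training package skips the horovod dependency; for example, DS_NOHOROVOD=1 pip install -e . (assuming an editable install from the directory containing this setup script).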

View file

@@ -424,7 +424,8 @@ def train():
process_ahead=len(Config.available_devices) * FLAGS.train_batch_size * 2,
reverse=FLAGS.reverse_train,
limit=FLAGS.limit_train,
buffering=FLAGS.read_buffer)
buffering=FLAGS.read_buffer,
split_dataset=False)
iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
tfv1.data.get_output_shapes(train_set),
@@ -442,7 +443,8 @@ def train():
process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2,
reverse=FLAGS.reverse_dev,
limit=FLAGS.limit_dev,
buffering=FLAGS.read_buffer) for source in dev_sources]
buffering=FLAGS.read_buffer,
split_dataset=False) for source in dev_sources]
dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]
if FLAGS.metrics_files:
@@ -454,7 +456,8 @@ def train():
process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2,
reverse=FLAGS.reverse_dev,
limit=FLAGS.limit_dev,
buffering=FLAGS.read_buffer) for source in metrics_sources]
buffering=FLAGS.read_buffer,
split_dataset=False) for source in metrics_sources]
metrics_init_ops = [iterator.make_initializer(metrics_set) for metrics_set in metrics_sets]
# Dropout
@@ -677,6 +680,303 @@ def train():
log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time))
log_debug('Session closed.')
def train_with_horovod():
import horovod.tensorflow as hvd
exception_box = ExceptionBox()
# Create training and validation datasets
train_set = create_dataset(FLAGS.train_files.split(','),
batch_size=FLAGS.train_batch_size,
epochs=FLAGS.epochs,
augmentations=Config.augmentations,
cache_path=FLAGS.feature_cache,
train_phase=True,
exception_box=exception_box,
process_ahead=Config.num_devices * FLAGS.train_batch_size * 2,
reverse=FLAGS.reverse_train,
limit=FLAGS.limit_train,
buffering=FLAGS.read_buffer,
split_dataset=True)
iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
tfv1.data.get_output_shapes(train_set),
output_classes=tfv1.data.get_output_classes(train_set))
# Make initialization ops for switching between the two sets
train_init_op = iterator.make_initializer(train_set)
if FLAGS.dev_files:
dev_sources = FLAGS.dev_files.split(',')
dev_sets = [create_dataset([source],
batch_size=FLAGS.dev_batch_size,
train_phase=False,
exception_box=exception_box,
process_ahead=Config.num_devices * FLAGS.dev_batch_size * 2,
reverse=FLAGS.reverse_dev,
limit=FLAGS.limit_dev,
buffering=FLAGS.read_buffer,
split_dataset=True) for source in dev_sources]
dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]
if FLAGS.metrics_files:
metrics_sources = FLAGS.metrics_files.split(',')
metrics_sets = [create_dataset([source],
batch_size=FLAGS.dev_batch_size,
train_phase=False,
exception_box=exception_box,
process_ahead=Config.num_devices * FLAGS.dev_batch_size * 2,
reverse=FLAGS.reverse_dev,
limit=FLAGS.limit_dev,
buffering=FLAGS.read_buffer,
split_dataset=True) for source in metrics_sources]
metrics_init_ops = [iterator.make_initializer(metrics_set) for metrics_set in metrics_sets]
# Dropout
dropout_rates = [tfv1.placeholder(tf.float32, name='dropout_{}'.format(i)) for i in range(6)]
dropout_feed_dict = {
dropout_rates[0]: FLAGS.dropout_rate,
dropout_rates[1]: FLAGS.dropout_rate2,
dropout_rates[2]: FLAGS.dropout_rate3,
dropout_rates[3]: FLAGS.dropout_rate4,
dropout_rates[4]: FLAGS.dropout_rate5,
dropout_rates[5]: FLAGS.dropout_rate6,
}
no_dropout_feed_dict = {
rate: 0. for rate in dropout_rates
}
# Building the graph
learning_rate_var = tfv1.get_variable('learning_rate', initializer=FLAGS.learning_rate, trainable=False)
reduce_learning_rate_op = learning_rate_var.assign(tf.multiply(learning_rate_var, FLAGS.plateau_reduction))
# Effective batch size in synchronous distributed training is scaled by the number of workers. An increase in learning rate compensates for the increased batch size.
optimizer = create_optimizer(learning_rate_var * hvd.size())
optimizer = hvd.DistributedOptimizer(optimizer)
# Enable mixed precision training
if FLAGS.automatic_mixed_precision:
log_info('Enabling automatic mixed precision training.')
optimizer = tfv1.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)
loss, non_finite_files = calculate_mean_edit_distance_and_loss(iterator, dropout_rates, reuse=False)
gradients = optimizer.compute_gradients(loss)
tfv1.summary.scalar(name='step_loss', tensor=loss, collections=['step_summaries'])
log_grads_and_vars(gradients)
# global_step is automagically incremented by the optimizer
global_step = tfv1.train.get_or_create_global_step()
apply_gradient_op = optimizer.apply_gradients(gradients, global_step=global_step)
# Summaries
step_summaries_op = tfv1.summary.merge_all('step_summaries')
step_summary_writers = {
'train': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'train'), max_queue=120),
'dev': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'dev'), max_queue=120),
'metrics': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'metrics'), max_queue=120),
}
human_readable_set_names = {
'train': 'Training',
'dev': 'Validation',
'metrics': 'Metrics',
}
# Checkpointing
if Config.is_master_process:
checkpoint_saver = tfv1.train.Saver(max_to_keep=FLAGS.max_to_keep)
checkpoint_path = os.path.join(FLAGS.save_checkpoint_dir, 'train')
best_dev_saver = tfv1.train.Saver(max_to_keep=1)
best_dev_path = os.path.join(FLAGS.save_checkpoint_dir, 'best_dev')
# Save flags next to checkpoints
if not is_remote_path(FLAGS.save_checkpoint_dir):
os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True)
flags_file = os.path.join(FLAGS.save_checkpoint_dir, 'flags.txt')
with open_remote(flags_file, 'w') as fout:
fout.write(FLAGS.flags_into_string())
bcast = hvd.broadcast_global_variables(0)
with tfv1.Session(config=Config.session_config) as session:
log_debug('Session opened.')
# Prevent further graph changes
tfv1.get_default_graph().finalize()
# Load checkpoint or initialize variables
load_or_init_graph_for_training(session)
bcast.run()
def run_set(set_name, epoch, init_op, dataset=None):
is_train = set_name == 'train'
train_op = apply_gradient_op if is_train else []
feed_dict = dropout_feed_dict if is_train else no_dropout_feed_dict
total_loss = 0.0
step_count = 0
step_summary_writer = step_summary_writers.get(set_name)
checkpoint_time = time.time()
if is_train and FLAGS.cache_for_epochs > 0 and FLAGS.feature_cache:
feature_cache_index = FLAGS.feature_cache + '.index'
if epoch % FLAGS.cache_for_epochs == 0 and os.path.isfile(feature_cache_index):
log_info('Invalidating feature cache')
remove_remote(feature_cache_index) # this will let TF also overwrite the related cache data files
# Setup progress bar
class LossWidget(progressbar.widgets.FormatLabel):
def __init__(self):
progressbar.widgets.FormatLabel.__init__(self, format='Loss: %(mean_loss)f')
def __call__(self, progress, data, **kwargs):
data['mean_loss'] = total_loss / step_count if step_count else 0.0
return progressbar.widgets.FormatLabel.__call__(self, progress, data, **kwargs)
if Config.is_master_process:
# TODO endl seems not to work with horovod
prefix = 'Epoch {} | {:>10}'.format(epoch, human_readable_set_names[set_name])
widgets = [' | ', progressbar.widgets.Timer(),
' | Steps: ', progressbar.widgets.Counter(),
' | ', LossWidget()]
suffix = ' | Dataset: {}'.format(dataset) if dataset else None
pbar = create_progressbar(prefix=prefix, widgets=widgets, suffix=suffix).start()
# Initialize iterator to the appropriate dataset
session.run(init_op)
# Batch loop
while True:
try:
_, current_step, batch_loss, problem_files, step_summary = \
session.run([train_op, global_step, loss, non_finite_files, step_summaries_op],
feed_dict=feed_dict)
exception_box.raise_if_set()
except tf.errors.OutOfRangeError:
exception_box.raise_if_set()
break
if problem_files.size > 0:
problem_files = [f.decode('utf8') for f in problem_files[..., 0]]
log_error('The following files caused an infinite (or NaN) '
'loss: {}'.format(','.join(problem_files)))
total_loss += batch_loss
step_count += 1
if Config.is_master_process:
pbar.update(step_count)
step_summary_writer.add_summary(step_summary, current_step)
if is_train and FLAGS.checkpoint_secs > 0 and time.time() - checkpoint_time > FLAGS.checkpoint_secs:
checkpoint_saver.save(session, checkpoint_path, global_step=current_step)
checkpoint_time = time.time()
pbar.finish()
mean_loss = total_loss / step_count if step_count > 0 else 0.0
return mean_loss, step_count
log_info('STARTING Optimization')
train_start_time = datetime.utcnow()
best_dev_loss = float('inf')
dev_losses = []
epochs_without_improvement = 0
try:
for epoch in range(FLAGS.epochs):
# Training
if Config.is_master_process:
log_progress('Training epoch %d...' % epoch)
train_loss, _ = run_set('train', epoch, train_init_op)
if Config.is_master_process:
log_progress('Finished training epoch %d - loss: %f' % (epoch, train_loss))
checkpoint_saver.save(session, checkpoint_path, global_step=global_step)
if FLAGS.dev_files:
# Validation
dev_loss = 0.0
total_steps = 0
for source, init_op in zip(dev_sources, dev_init_ops):
if Config.is_master_process:
log_progress('Validating epoch %d on %s...' % (epoch, source))
set_loss, steps = run_set('dev', epoch, init_op, dataset=source)
dev_loss += set_loss * steps
total_steps += steps
if Config.is_master_process:
log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, source, set_loss))
dev_loss = dev_loss / total_steps
dev_losses.append(dev_loss)
# Count epochs without an improvement for early stopping and reduction of learning rate on a plateau
# the improvement has to be greater than FLAGS.es_min_delta
if dev_loss > best_dev_loss - FLAGS.es_min_delta:
epochs_without_improvement += 1
else:
epochs_without_improvement = 0
if Config.is_master_process:
# Save new best model
if dev_loss < best_dev_loss:
best_dev_loss = dev_loss
save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step,
latest_filename='best_dev_checkpoint')
log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path))
# Early stopping
if FLAGS.early_stop and epochs_without_improvement == FLAGS.es_epochs:
if Config.is_master_process:
log_info('Early stop triggered as the loss did not improve in the last {} epochs'.format(
epochs_without_improvement))
break
# Reduce learning rate on plateau
# If the learning rate was reduced and there is still no improvement
# wait FLAGS.plateau_epochs before the learning rate is reduced again
if (
FLAGS.reduce_lr_on_plateau
and epochs_without_improvement > 0
and epochs_without_improvement % FLAGS.plateau_epochs == 0
):
# Reload checkpoint so that we use the best_dev weights again
reload_best_checkpoint(session)
# Reduce learning rate
session.run(reduce_learning_rate_op)
current_learning_rate = learning_rate_var.eval()
if Config.is_master_process:
log_info('Encountered a plateau, reducing learning rate to {}'.format(
current_learning_rate))
# Overwrite best checkpoint with new learning rate value
save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step,
latest_filename='best_dev_checkpoint')
log_info("Saved best validating model with reduced learning rate to: %s" % (save_path))
if FLAGS.metrics_files:
# Read only metrics, not affecting best validation loss tracking
for source, init_op in zip(metrics_sources, metrics_init_ops):
if Config.is_master_process:
log_progress('Metrics for epoch %d on %s...' % (epoch, source))
set_loss, _ = run_set('metrics', epoch, init_op, dataset=source)
if Config.is_master_process:
log_progress('Metrics for epoch %d on %s - loss: %f' % (epoch, source, set_loss))
if Config.is_master_process:
print('-' * 80)
except KeyboardInterrupt:
pass
if Config.is_master_process:
log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time))
if Config.is_master_process:
log_debug('Session closed.')
def test():
samples = evaluate(FLAGS.test_files.split(','), create_model)
@@ -951,8 +1251,13 @@ def main(_):
if FLAGS.train_files:
tfv1.reset_default_graph()
tfv1.set_random_seed(FLAGS.random_seed)
if FLAGS.horovod:
train_with_horovod()
else:
train()
if Config.is_master_process:
if FLAGS.test_files:
tfv1.reset_default_graph()
test()
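
For readers unfamiliar with Horovod's TensorFlow 1.x API, the following standalone sketch (a toy model with made-up hyperparameters, not DeepSpeech code) illustrates the pattern train_with_horovod() follows: initialize Horovod, pin one GPU per process via its local rank, scale the learning rate by hvd.size(), wrap the optimizer in hvd.DistributedOptimizer, and broadcast the initial variables from rank 0:

    # Toy Horovod + TF1 training setup; assumes TensorFlow 1.15, as pinned by setup.py.
    import tensorflow.compat.v1 as tf
    import horovod.tensorflow as hvd

    hvd.init()                                            # one training process per GPU

    config = tf.ConfigProto()
    # Pin each process to a single GPU, selected by its local rank on the machine.
    config.gpu_options.visible_device_list = str(hvd.local_rank())

    x = tf.placeholder(tf.float32, shape=[None, 1])
    y = tf.placeholder(tf.float32, shape=[None, 1])
    w = tf.get_variable('w', shape=[1, 1])
    loss = tf.reduce_mean(tf.square(tf.matmul(x, w) - y))

    # The effective batch size grows with the number of workers, so the learning
    # rate is scaled by hvd.size(); DistributedOptimizer averages the gradients
    # across all workers before they are applied.
    optimizer = tf.train.AdamOptimizer(0.001 * hvd.size())
    optimizer = hvd.DistributedOptimizer(optimizer)
    global_step = tf.train.get_or_create_global_step()
    train_op = optimizer.minimize(loss, global_step=global_step)

    # Broadcast the initial (or checkpoint-restored) variables from rank 0 so
    # all workers start from identical weights.
    bcast = hvd.broadcast_global_variables(0)

    with tf.Session(config=config) as session:
        session.run(tf.global_variables_initializer())
        session.run(bcast)
        # ... feed batches and run train_op here; only rank 0 (hvd.rank() == 0)
        # should write checkpoints and summaries, as Config.is_master_process does above.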

View file

@@ -79,13 +79,34 @@ def initialize_globals():
# CPU device
c.cpu_device = '/cpu:0'
# Available GPU devices
if FLAGS.horovod:
try:
import horovod.tensorflow as hvd
except ImportError as e:
print(
"Error importing Horovod. Did you install DeepSpeech with DS_NOHOROVOD set? "
"If you do not want to use Horovod, run the training without the --horovod flag.")
raise e
hvd.init()
# Pin GPU to be used to process local rank (one GPU per process)
c.session_config.gpu_options.visible_device_list = str(hvd.local_rank())
c.num_devices = hvd.size()
c.is_master_process = hvd.rank() == 0
else:
# Available GPU devices
c.available_devices = get_available_gpus(c.session_config)
# If there is no GPU available, we fall back to CPU based operation
if not c.available_devices:
c.available_devices = [c.cpu_device]
c.num_devices = len(c.available_devices)
# If Horovod is not used, the single process is handled like the Horovod master process
c.is_master_process = True
if FLAGS.bytes_output_mode:
c.alphabet = UTF8Alphabet()
else:
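
As a side note on the two rank functions used above: hvd.rank() is a process's global index across all machines, while hvd.local_rank() is its index on its own machine, which is why local_rank selects the GPU and global rank 0 acts as the master process. A toy illustration (the launch command in the comment is hypothetical):

    # Hypothetical launch: horovodrun -np 4 -H serverA:2,serverB:2 python script.py
    import horovod.tensorflow as hvd

    hvd.init()
    print('size:      ', hvd.size())        # 4 -> total number of processes
    print('rank:      ', hvd.rank())        # 0..3 -> unique across both machines
    print('local_rank:', hvd.local_rank())  # 0..1 -> index within one machine, used to pick the GPU
    print('is_master: ', hvd.rank() == 0)   # only the global rank-0 process logs and writes checkpoints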

View file

@@ -94,7 +94,8 @@ def create_dataset(sources,
limit=0,
exception_box=None,
process_ahead=None,
buffering=1 * MEGABYTE):
buffering=1 * MEGABYTE,
split_dataset=False):
epoch_counter = Counter() # survives restarts of the dataset and its generator
def generate_values():
@@ -135,17 +136,25 @@ def create_dataset(sources,
process_fn = partial(entry_to_features, train_phase=train_phase, augmentations=augmentations)
dataset = (tf.data.Dataset.from_generator(remember_exception(generate_values, exception_box),
output_types=(tf.string, tf.float32, tf.int32,
(tf.int64, tf.int32, tf.int64), tf.float64))
.map(process_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE))
dataset = tf.data.Dataset.from_generator(remember_exception(generate_values, exception_box),
output_types=(tf.string, tf.float32, tf.int32,
(tf.int64, tf.int32, tf.int64), tf.float64))
if split_dataset:
# When using Horovod, Iterator.get_next() is not aware of the other workers,
# so each process takes its own shard: A.shard(n, i) contains all elements of A whose index mod n == i.
import horovod.tensorflow as hvd
dataset = dataset.shard(hvd.size(), hvd.rank())
dataset = dataset.map(process_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
if cache_path:
dataset = dataset.cache(cache_path)
dataset = (dataset.window(batch_size, drop_remainder=train_phase).flat_map(batch_fn)
.prefetch(len(Config.available_devices)))
dataset = (dataset.window(batch_size, drop_remainder=train_phase).flat_map(batch_fn))
if split_dataset:
# TODO: is there a way to get a proper value?
dataset = dataset.prefetch(2)
else:
dataset = dataset.prefetch(Config.num_devices)
return dataset
def split_audio_file(audio_path,
audio_format=DEFAULT_FORMAT,
batch_size=1,
@@ -178,5 +187,5 @@ def split_audio_file(audio_path,
ods = create_batch_set(outlier_batch_size,
lambda start, end, f, fl: end - start > int(outlier_duration_ms))
dataset = nds.concatenate(ods)
dataset = dataset.prefetch(len(Config.available_devices))
dataset = dataset.prefetch(Config.num_devices)
return dataset
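
For reference, Dataset.shard(n, i) keeps exactly the elements whose position in the dataset is congruent to i modulo n, so each Horovod worker ends up reading a disjoint subset of the input. A toy illustration without Horovod (the shard count and index are arbitrary):

    import tensorflow.compat.v1 as tf

    dataset = tf.data.Dataset.range(10)
    # Worker i of n would call dataset.shard(n, i); here n=4 and i=1 as an example.
    shard = dataset.shard(num_shards=4, index=1)

    iterator = tf.data.make_one_shot_iterator(shard)
    next_element = iterator.get_next()
    with tf.Session() as session:
        try:
            while True:
                print(session.run(next_element))  # prints 1, 5, 9
        except tf.errors.OutOfRangeError:
            pass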

View file

@@ -69,6 +69,8 @@ def create_flags():
f.DEFINE_boolean('train_cudnn', False, 'use CuDNN RNN backend for training on GPU. Note that checkpoints created with this flag can only be used with CuDNN RNN, i.e. fine tuning on a CPU device will not work')
f.DEFINE_boolean('automatic_mixed_precision', False, 'whether to allow automatic mixed precision training. USE OF THIS FLAG IS UNSUPPORTED. Checkpoints created with automatic mixed precision training will not be usable without mixed precision.')
f.DEFINE_boolean('horovod', False, 'use Horovod for distributed training on multiple GPUs')
# Sample limits
f.DEFINE_integer('limit_train', 0, 'maximum number of elements to use from train set - 0 means no limit')