Mirror of https://github.com/mozilla/DeepSpeech.git

implement distributed training using horovod

This commit is contained in:
Parent: 7b2eeb6734
Commit: 11edd92775
@@ -196,6 +196,21 @@ python3 DeepSpeech.py --train_files ./train.csv --dev_files ./dev.csv --test_fil
 On a Volta generation V100 GPU, automatic mixed precision speeds up DeepSpeech training and evaluation by ~30%-40%.
 
+Distributed training using Horovod
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+If you have a capable compute architecture, training can be distributed across several machines and GPUs using `Horovod <https://github.com/horovod/horovod>`_. A fast network is recommended.
+Horovod is capable of using MPI and NVIDIA's NCCL for highly optimized inter-process communication.
+It also offers Gloo as an easy-to-set-up communication backend.
+
+For more information about setup or tuning of Horovod, please visit `Horovod's GitHub <https://github.com/horovod/horovod>`_.
+
+To train on 4 machines using 4 GPUs each:
+
+.. code-block:: bash
+
+    horovodrun -np 16 -H server1:4,server2:4,server3:4,server4:4 python3 DeepSpeech.py --train_files [...] --horovod
+
 Checkpointing
 ^^^^^^^^^^^^^
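
The Gloo backend mentioned in the added section needs neither MPI nor NCCL; as a hedged sketch (assuming a Horovod build with Gloo support), the same 16-process run can be launched with ``horovodrun --gloo -np 16 -H server1:4,server2:4,server3:4,server4:4 python3 DeepSpeech.py --train_files [...] --horovod``.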
setup.py | 10

@@ -76,6 +76,10 @@ def main():
         'tensorflow == 1.15.4'
     ]
 
+    horovod_pypi_dep = [
+        'horovod'
+    ]
+
     # Due to pip craziness environment variables are the only consistent way to
     # get options into this script when doing `pip install`.
     tc_decoder_artifacts_root = os.environ.get('DECODER_ARTIFACTS_ROOT', '')
@@ -94,6 +98,12 @@ def main():
     else:
         install_requires = install_requires + tensorflow_pypi_dep
 
+    if os.environ.get('DS_NOHOROVOD', ''):
+        install_requires = install_requires
+    else:
+        install_requires = install_requires + horovod_pypi_dep
+
     setup(
         name='deepspeech_training',
         version=version,
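
The ``DS_NOHOROVOD`` check added above lets users opt out of the Horovod dependency at install time by exporting the variable before running pip, for example ``DS_NOHOROVOD=1 pip3 install -e .`` (the command is illustrative and assumes an editable install of the training package from the directory containing this setup.py).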
@@ -424,7 +424,8 @@ def train():
                                process_ahead=len(Config.available_devices) * FLAGS.train_batch_size * 2,
                                reverse=FLAGS.reverse_train,
                                limit=FLAGS.limit_train,
-                               buffering=FLAGS.read_buffer)
+                               buffering=FLAGS.read_buffer,
+                               split_dataset=False)
 
     iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
                                                  tfv1.data.get_output_shapes(train_set),
@@ -442,7 +443,8 @@ def train():
                                    process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2,
                                    reverse=FLAGS.reverse_dev,
                                    limit=FLAGS.limit_dev,
-                                   buffering=FLAGS.read_buffer) for source in dev_sources]
+                                   buffering=FLAGS.read_buffer,
+                                   split_dataset=False) for source in dev_sources]
         dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]
 
     if FLAGS.metrics_files:
@@ -454,7 +456,8 @@ def train():
                                        process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2,
                                        reverse=FLAGS.reverse_dev,
                                        limit=FLAGS.limit_dev,
-                                       buffering=FLAGS.read_buffer) for source in metrics_sources]
+                                       buffering=FLAGS.read_buffer,
+                                       split_dataset=False) for source in metrics_sources]
         metrics_init_ops = [iterator.make_initializer(metrics_set) for metrics_set in metrics_sets]
 
     # Dropout
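
Note that the existing single-process ``train()`` path now passes ``split_dataset=False``, so each dataset is still consumed in full by the one training process; the Horovod path added below passes ``split_dataset=True`` so that every worker reads only its own shard.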
@@ -677,6 +680,303 @@ def train():
             log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time))
     log_debug('Session closed.')
 
+
+def train_with_horovod():
+    import horovod.tensorflow as hvd
+
+    exception_box = ExceptionBox()
+
+    # Create training and validation datasets
+    train_set = create_dataset(FLAGS.train_files.split(','),
+                               batch_size=FLAGS.train_batch_size,
+                               epochs=FLAGS.epochs,
+                               augmentations=Config.augmentations,
+                               cache_path=FLAGS.feature_cache,
+                               train_phase=True,
+                               exception_box=exception_box,
+                               process_ahead=Config.num_devices * FLAGS.train_batch_size * 2,
+                               reverse=FLAGS.reverse_train,
+                               limit=FLAGS.limit_train,
+                               buffering=FLAGS.read_buffer,
+                               split_dataset=True)
+
+    iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
+                                                 tfv1.data.get_output_shapes(train_set),
+                                                 output_classes=tfv1.data.get_output_classes(train_set))
+
+    # Make initialization ops for switching between the two sets
+    train_init_op = iterator.make_initializer(train_set)
+
+    if FLAGS.dev_files:
+        dev_sources = FLAGS.dev_files.split(',')
+        dev_sets = [create_dataset([source],
+                                   batch_size=FLAGS.dev_batch_size,
+                                   train_phase=False,
+                                   exception_box=exception_box,
+                                   process_ahead=Config.num_devices * FLAGS.dev_batch_size * 2,
+                                   reverse=FLAGS.reverse_dev,
+                                   limit=FLAGS.limit_dev,
+                                   buffering=FLAGS.read_buffer,
+                                   split_dataset=True) for source in dev_sources]
+        dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]
+
+    if FLAGS.metrics_files:
+        metrics_sources = FLAGS.metrics_files.split(',')
+        metrics_sets = [create_dataset([source],
+                                       batch_size=FLAGS.dev_batch_size,
+                                       train_phase=False,
+                                       exception_box=exception_box,
+                                       process_ahead=Config.num_devices * FLAGS.dev_batch_size * 2,
+                                       reverse=FLAGS.reverse_dev,
+                                       limit=FLAGS.limit_dev,
+                                       buffering=FLAGS.read_buffer,
+                                       split_dataset=True) for source in metrics_sources]
+        metrics_init_ops = [iterator.make_initializer(metrics_set) for metrics_set in metrics_sets]
+
+    # Dropout
+    dropout_rates = [tfv1.placeholder(tf.float32, name='dropout_{}'.format(i)) for i in range(6)]
+    dropout_feed_dict = {
+        dropout_rates[0]: FLAGS.dropout_rate,
+        dropout_rates[1]: FLAGS.dropout_rate2,
+        dropout_rates[2]: FLAGS.dropout_rate3,
+        dropout_rates[3]: FLAGS.dropout_rate4,
+        dropout_rates[4]: FLAGS.dropout_rate5,
+        dropout_rates[5]: FLAGS.dropout_rate6,
+    }
+    no_dropout_feed_dict = {
+        rate: 0. for rate in dropout_rates
+    }
+
+    # Building the graph
+    learning_rate_var = tfv1.get_variable('learning_rate', initializer=FLAGS.learning_rate, trainable=False)
+    reduce_learning_rate_op = learning_rate_var.assign(tf.multiply(learning_rate_var, FLAGS.plateau_reduction))
+
+    # Effective batch size in synchronous distributed training is scaled by the number of workers. An increase in learning rate compensates for the increased batch size.
+    optimizer = create_optimizer(learning_rate_var * hvd.size())
+    optimizer = hvd.DistributedOptimizer(optimizer)
+
+    # Enable mixed precision training
+    if FLAGS.automatic_mixed_precision:
+        log_info('Enabling automatic mixed precision training.')
+        optimizer = tfv1.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)
+
+    loss, non_finite_files = calculate_mean_edit_distance_and_loss(iterator, dropout_rates, reuse=False)
+    gradients = optimizer.compute_gradients(loss)
+
+    tfv1.summary.scalar(name='step_loss', tensor=loss, collections=['step_summaries'])
+    log_grads_and_vars(gradients)
+
+    # global_step is automagically incremented by the optimizer
+    global_step = tfv1.train.get_or_create_global_step()
+    apply_gradient_op = optimizer.apply_gradients(gradients, global_step=global_step)
+
+    # Summaries
+    step_summaries_op = tfv1.summary.merge_all('step_summaries')
+    step_summary_writers = {
+        'train': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'train'), max_queue=120),
+        'dev': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'dev'), max_queue=120),
+        'metrics': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'metrics'), max_queue=120),
+    }
+
+    human_readable_set_names = {
+        'train': 'Training',
+        'dev': 'Validation',
+        'metrics': 'Metrics',
+    }
+
+    # Checkpointing
+    if Config.is_master_process:
+        checkpoint_saver = tfv1.train.Saver(max_to_keep=FLAGS.max_to_keep)
+        checkpoint_path = os.path.join(FLAGS.save_checkpoint_dir, 'train')
+
+        best_dev_saver = tfv1.train.Saver(max_to_keep=1)
+        best_dev_path = os.path.join(FLAGS.save_checkpoint_dir, 'best_dev')
+
+        # Save flags next to checkpoints
+        if not is_remote_path(FLAGS.save_checkpoint_dir):
+            os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True)
+        flags_file = os.path.join(FLAGS.save_checkpoint_dir, 'flags.txt')
+        with open_remote(flags_file, 'w') as fout:
+            fout.write(FLAGS.flags_into_string())
+
+    bcast = hvd.broadcast_global_variables(0)
+
+    with tfv1.Session(config=Config.session_config) as session:
+        log_debug('Session opened.')
+
+        # Prevent further graph changes
+        tfv1.get_default_graph().finalize()
+
+        # Load checkpoint or initialize variables
+        load_or_init_graph_for_training(session)
+        bcast.run()
+
+        def run_set(set_name, epoch, init_op, dataset=None):
+            is_train = set_name == 'train'
+            train_op = apply_gradient_op if is_train else []
+            feed_dict = dropout_feed_dict if is_train else no_dropout_feed_dict
+
+            total_loss = 0.0
+            step_count = 0
+
+            step_summary_writer = step_summary_writers.get(set_name)
+            checkpoint_time = time.time()
+
+            if is_train and FLAGS.cache_for_epochs > 0 and FLAGS.feature_cache:
+                feature_cache_index = FLAGS.feature_cache + '.index'
+                if epoch % FLAGS.cache_for_epochs == 0 and os.path.isfile(feature_cache_index):
+                    log_info('Invalidating feature cache')
+                    remove_remote(feature_cache_index)  # this will let TF also overwrite the related cache data files
+
+            # Setup progress bar
+            class LossWidget(progressbar.widgets.FormatLabel):
+                def __init__(self):
+                    progressbar.widgets.FormatLabel.__init__(self, format='Loss: %(mean_loss)f')
+
+                def __call__(self, progress, data, **kwargs):
+                    data['mean_loss'] = total_loss / step_count if step_count else 0.0
+                    return progressbar.widgets.FormatLabel.__call__(self, progress, data, **kwargs)
+
+            if Config.is_master_process:
+                # TODO: endl seems not to work with horovod
+                prefix = 'Epoch {} | {:>10}'.format(epoch, human_readable_set_names[set_name])
+                widgets = [' | ', progressbar.widgets.Timer(),
+                           ' | Steps: ', progressbar.widgets.Counter(),
+                           ' | ', LossWidget()]
+                suffix = ' | Dataset: {}'.format(dataset) if dataset else None
+                pbar = create_progressbar(prefix=prefix, widgets=widgets, suffix=suffix).start()
+
+            # Initialize iterator to the appropriate dataset
+            session.run(init_op)
+
+            # Batch loop
+            while True:
+                try:
+                    _, current_step, batch_loss, problem_files, step_summary = \
+                        session.run([train_op, global_step, loss, non_finite_files, step_summaries_op],
+                                    feed_dict=feed_dict)
+                    exception_box.raise_if_set()
+                except tf.errors.OutOfRangeError:
+                    exception_box.raise_if_set()
+                    break
+
+                if problem_files.size > 0:
+                    problem_files = [f.decode('utf8') for f in problem_files[..., 0]]
+                    log_error('The following files caused an infinite (or NaN) '
+                              'loss: {}'.format(','.join(problem_files)))
+
+                total_loss += batch_loss
+                step_count += 1
+
+                if Config.is_master_process:
+                    pbar.update(step_count)
+                    step_summary_writer.add_summary(step_summary, current_step)
+
+                    if is_train and FLAGS.checkpoint_secs > 0 and time.time() - checkpoint_time > FLAGS.checkpoint_secs:
+                        checkpoint_saver.save(session, checkpoint_path, global_step=current_step)
+                        checkpoint_time = time.time()
+
+            if Config.is_master_process:
+                pbar.finish()
+            mean_loss = total_loss / step_count if step_count > 0 else 0.0
+            return mean_loss, step_count
+
+        log_info('STARTING Optimization')
+        train_start_time = datetime.utcnow()
+        best_dev_loss = float('inf')
+        dev_losses = []
+        epochs_without_improvement = 0
+        try:
+            for epoch in range(FLAGS.epochs):
+                # Training
+                if Config.is_master_process:
+                    log_progress('Training epoch %d...' % epoch)
+                train_loss, _ = run_set('train', epoch, train_init_op)
+                if Config.is_master_process:
+                    log_progress('Finished training epoch %d - loss: %f' % (epoch, train_loss))
+                    checkpoint_saver.save(session, checkpoint_path, global_step=global_step)
+
+                if FLAGS.dev_files:
+                    # Validation
+                    dev_loss = 0.0
+                    total_steps = 0
+                    for source, init_op in zip(dev_sources, dev_init_ops):
+                        if Config.is_master_process:
+                            log_progress('Validating epoch %d on %s...' % (epoch, source))
+                        set_loss, steps = run_set('dev', epoch, init_op, dataset=source)
+                        dev_loss += set_loss * steps
+                        total_steps += steps
+                        if Config.is_master_process:
+                            log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, source, set_loss))
+
+                    dev_loss = dev_loss / total_steps
+                    dev_losses.append(dev_loss)
+
+                    # Count epochs without an improvement for early stopping and reduction of learning rate on a plateau
+                    # the improvement has to be greater than FLAGS.es_min_delta
+                    if dev_loss > best_dev_loss - FLAGS.es_min_delta:
+                        epochs_without_improvement += 1
+                    else:
+                        epochs_without_improvement = 0
+
+                    if Config.is_master_process:
+                        # Save new best model
+                        if dev_loss < best_dev_loss:
+                            best_dev_loss = dev_loss
+                            save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step,
+                                                            latest_filename='best_dev_checkpoint')
+                            log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path))
+
+                    # Early stopping
+                    if FLAGS.early_stop and epochs_without_improvement == FLAGS.es_epochs:
+                        if Config.is_master_process:
+                            log_info('Early stop triggered as the loss did not improve the last {} epochs'.format(
+                                epochs_without_improvement))
+                        break
+
+                    # Reduce learning rate on plateau
+                    # If the learning rate was reduced and there is still no improvement
+                    # wait FLAGS.plateau_epochs before the learning rate is reduced again
+                    if (
+                        FLAGS.reduce_lr_on_plateau
+                        and epochs_without_improvement > 0
+                        and epochs_without_improvement % FLAGS.plateau_epochs == 0
+                    ):
+                        # Reload checkpoint so that we use the best_dev weights again
+                        reload_best_checkpoint(session)
+
+                        # Reduce learning rate
+                        session.run(reduce_learning_rate_op)
+                        current_learning_rate = learning_rate_var.eval()
+                        if Config.is_master_process:
+                            log_info('Encountered a plateau, reducing learning rate to {}'.format(
+                                current_learning_rate))
+
+                            # Overwrite best checkpoint with new learning rate value
+                            save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step,
+                                                            latest_filename='best_dev_checkpoint')
+                            log_info("Saved best validating model with reduced learning rate to: %s" % (save_path))
+
+                if FLAGS.metrics_files:
+                    # Read only metrics, not affecting best validation loss tracking
+                    for source, init_op in zip(metrics_sources, metrics_init_ops):
+                        if Config.is_master_process:
+                            log_progress('Metrics for epoch %d on %s...' % (epoch, source))
+                        set_loss, _ = run_set('metrics', epoch, init_op, dataset=source)
+                        if Config.is_master_process:
+                            log_progress('Metrics for epoch %d on %s - loss: %f' % (epoch, source, set_loss))
+
+                if Config.is_master_process:
+                    print('-' * 80)
+
+        except KeyboardInterrupt:
+            pass
+        if Config.is_master_process:
+            log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time))
+    if Config.is_master_process:
+        log_debug('Session closed.')
+
+
 def test():
     samples = evaluate(FLAGS.test_files.split(','), create_model)
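
The "effective batch size" comment in the new function is worth unpacking: with the 4x4-GPU launch from the documentation and, say, ``--train_batch_size 32``, each synchronous step processes 16 * 32 = 512 samples globally, which is why the base learning rate is multiplied by ``hvd.size()``. Below is a minimal, self-contained sketch of the same Horovod pattern (initialize, pin one GPU per process, scale the learning rate, wrap the optimizer, broadcast initial weights from rank 0); it uses a hypothetical toy model, not DeepSpeech code, and assumes TensorFlow 1.15 plus ``horovod[tensorflow]``:

.. code-block:: python

    import tensorflow.compat.v1 as tfv1
    import horovod.tensorflow as hvd

    hvd.init()  # one process per GPU, started via horovodrun

    # Pin one GPU per process, as the config changes below do with hvd.local_rank()
    session_config = tfv1.ConfigProto()
    session_config.gpu_options.visible_device_list = str(hvd.local_rank())

    # Toy model: fit a single scalar weight
    w = tfv1.get_variable('w', initializer=0.0)
    loss = tfv1.square(w - 3.0)

    base_learning_rate = 0.001
    optimizer = tfv1.train.GradientDescentOptimizer(base_learning_rate * hvd.size())
    optimizer = hvd.DistributedOptimizer(optimizer)  # averages gradients across workers
    global_step = tfv1.train.get_or_create_global_step()
    train_op = optimizer.minimize(loss, global_step=global_step)

    bcast = hvd.broadcast_global_variables(0)  # sync initial weights from rank 0

    with tfv1.Session(config=session_config) as session:
        session.run(tfv1.global_variables_initializer())
        session.run(bcast)
        for _ in range(100):
            session.run(train_op)
        if hvd.rank() == 0:  # the same role Config.is_master_process plays above
            print('final w:', session.run(w))

Launched with, for example, ``horovodrun -np 2 python sketch.py``, every process computes gradients on its own batch and applies the averaged update, so the effective batch size, and hence the scaled learning rate, grows with the number of workers.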
@@ -951,30 +1251,35 @@ def main(_):
     if FLAGS.train_files:
         tfv1.reset_default_graph()
         tfv1.set_random_seed(FLAGS.random_seed)
-        train()
 
-    if FLAGS.test_files:
-        tfv1.reset_default_graph()
-        test()
+        if FLAGS.horovod:
+            train_with_horovod()
+        else:
+            train()
 
-    if FLAGS.export_dir and not FLAGS.export_zip:
-        tfv1.reset_default_graph()
-        export()
+    if Config.is_master_process:
+        if FLAGS.test_files:
+            tfv1.reset_default_graph()
+            test()
 
-    if FLAGS.export_zip:
-        tfv1.reset_default_graph()
-        FLAGS.export_tflite = True
+        if FLAGS.export_dir and not FLAGS.export_zip:
+            tfv1.reset_default_graph()
+            export()
 
-        if listdir_remote(FLAGS.export_dir):
-            log_error('Directory {} is not empty, please fix this.'.format(FLAGS.export_dir))
-            sys.exit(1)
+        if FLAGS.export_zip:
+            tfv1.reset_default_graph()
+            FLAGS.export_tflite = True
 
-        export()
-        package_zip()
+            if listdir_remote(FLAGS.export_dir):
+                log_error('Directory {} is not empty, please fix this.'.format(FLAGS.export_dir))
+                sys.exit(1)
 
-    if FLAGS.one_shot_infer:
-        tfv1.reset_default_graph()
-        do_single_file_inference(FLAGS.one_shot_infer)
+            export()
+            package_zip()
+
+        if FLAGS.one_shot_infer:
+            tfv1.reset_default_graph()
+            do_single_file_inference(FLAGS.one_shot_infer)
 
 
 def run_script():
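
With ``--horovod``, every Horovod process executes ``main()``; the new ``Config.is_master_process`` guard above keeps testing, export, and one-shot inference on rank 0 only, while the other workers simply skip those steps after training.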
@@ -79,12 +79,33 @@ def initialize_globals():
     # CPU device
     c.cpu_device = '/cpu:0'
 
-    # Available GPU devices
-    c.available_devices = get_available_gpus(c.session_config)
-
-    # If there is no GPU available, we fall back to CPU based operation
-    if not c.available_devices:
-        c.available_devices = [c.cpu_device]
+    if FLAGS.horovod:
+        try:
+            import horovod.tensorflow as hvd
+        except ImportError as e:
+            print(
+                "Error importing Horovod. Did you installed DeepSpeech with -DNOHOROVOD? "
+                "If you do not want to use horovod, use 'from deepspeech_training import train'")
+            raise e
+
+        hvd.init()
+
+        # Pin GPU to be used to process local rank (one GPU per process)
+        c.session_config.gpu_options.visible_device_list = str(hvd.local_rank())
+        c.num_devices = hvd.size()
+        c.is_master_process = True if hvd.rank() == 0 else False
+    else:
+        # Available GPU devices
+        c.available_devices = get_available_gpus(c.session_config)
+
+        # If there is no GPU available, we fall back to CPU based operation
+        if not c.available_devices:
+            c.available_devices = [c.cpu_device]
+
+        c.num_devices = len(c.available_devices)
+
+        # If there are no horovod processes, the only process should be handled like the horovod master
+        c.is_master_process = True
 
     if FLAGS.bytes_output_mode:
         c.alphabet = UTF8Alphabet()
@@ -94,7 +94,8 @@ def create_dataset(sources,
                    limit=0,
                    exception_box=None,
                    process_ahead=None,
-                   buffering=1 * MEGABYTE):
+                   buffering=1 * MEGABYTE,
+                   split_dataset=False):
     epoch_counter = Counter()  # survives restarts of the dataset and its generator
 
     def generate_values():
@@ -135,17 +136,25 @@ def create_dataset(sources,
 
     process_fn = partial(entry_to_features, train_phase=train_phase, augmentations=augmentations)
 
-    dataset = (tf.data.Dataset.from_generator(remember_exception(generate_values, exception_box),
-                                              output_types=(tf.string, tf.float32, tf.int32,
-                                                            (tf.int64, tf.int32, tf.int64), tf.float64))
-                              .map(process_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE))
+    dataset = tf.data.Dataset.from_generator(remember_exception(generate_values, exception_box),
+                                             output_types=(tf.string, tf.float32, tf.int32,
+                                                           (tf.int64, tf.int32, tf.int64), tf.float64))
+    if split_dataset:
+        # When using horovod, Iterator.get_next() is not aware of different devices.
+        # A.shard(n, i) will contain all elements of A whose index mod n = i.
+        import horovod.tensorflow as hvd
+        dataset = dataset.shard(hvd.size(), hvd.rank())
+    dataset = dataset.map(process_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
     if cache_path:
         dataset = dataset.cache(cache_path)
-    dataset = (dataset.window(batch_size, drop_remainder=train_phase).flat_map(batch_fn)
-                      .prefetch(len(Config.available_devices)))
+    dataset = (dataset.window(batch_size, drop_remainder=train_phase).flat_map(batch_fn))
+    if split_dataset:
+        # TODO: is there a way to get a proper value?
+        dataset = dataset.prefetch(2)
+    else:
+        dataset = dataset.prefetch(Config.num_devices)
     return dataset
 
 
 def split_audio_file(audio_path,
                      audio_format=DEFAULT_FORMAT,
                      batch_size=1,
@@ -178,5 +187,5 @@ def split_audio_file(audio_path,
     ods = create_batch_set(outlier_batch_size,
                            lambda start, end, f, fl: end - start > int(outlier_duration_ms))
     dataset = nds.concatenate(ods)
-    dataset = dataset.prefetch(len(Config.available_devices))
+    dataset = dataset.prefetch(Config.num_devices)
     return dataset
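
The ``A.shard(n, i)`` comment in the hunk above is easiest to see on a toy dataset; here is a small illustration (it assumes TensorFlow 2.x eager mode purely for brevity, unlike the graph-mode code in the diff):

.. code-block:: python

    import tensorflow as tf

    # Each of n workers keeps only the elements whose index mod n equals its rank.
    dataset = tf.data.Dataset.range(10)
    for rank in range(4):  # pretend hvd.size() == 4
        shard = dataset.shard(num_shards=4, index=rank)
        print(rank, list(shard.as_numpy_iterator()))
    # 0 [0, 4, 8]
    # 1 [1, 5, 9]
    # 2 [2, 6]
    # 3 [3, 7]

Because each worker therefore sees roughly a ``1/hvd.size()`` fraction of the data per epoch, the fixed ``prefetch(2)`` above is a small per-process buffer rather than one sized for all local devices.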
@@ -69,6 +69,8 @@ def create_flags():
     f.DEFINE_boolean('train_cudnn', False, 'use CuDNN RNN backend for training on GPU. Note that checkpoints created with this flag can only be used with CuDNN RNN, i.e. fine tuning on a CPU device will not work')
     f.DEFINE_boolean('automatic_mixed_precision', False, 'whether to allow automatic mixed precision training. USE OF THIS FLAG IS UNSUPPORTED. Checkpoints created with automatic mixed precision training will not be usable without mixed precision.')
 
+    f.DEFINE_boolean('horovod', False, 'use horovod for training on multiple gpus')
+
     # Sample limits
     f.DEFINE_integer('limit_train', 0, 'maximum number of elements to use from train set - 0 means no limit')