зеркало из https://github.com/mozilla/DeepSpeech.git
implement distributed training using horovod
This commit is contained in:
Родитель
7b2eeb6734
Коммит
11edd92775
|
@ -196,6 +196,21 @@ python3 DeepSpeech.py --train_files ./train.csv --dev_files ./dev.csv --test_fil
|
|||
|
||||
On a Volta generation V100 GPU, automatic mixed precision speeds up DeepSpeech training and evaluation by ~30%-40%.
|
||||
|
||||
Distributed training using Horovod
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
If you have a capable compute architecture, we offer the opportunity to distribute the training using `Horovod <https://github.com/horovod/horovod>`_. A fast network is recommended.
|
||||
Horovod is capable of using MPI and NVIDIA's NCCL for highly optimized inter-process communication.
|
||||
It also offers Gloo as an easy-to-setup communication backend.
|
||||
|
||||
For more information about setup or tuning of Horovod please visit `Horovod's Github <https://github.com/horovod/horovod>`_.
|
||||
|
||||
To train on 4 machines using 4 GPUs each:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
horovodrun -np 16 -H server1:4,server2:4,server3:4,server4:4 python3 DeepSpeech.py --train_files [...] --horovod
|
||||
|
||||
Checkpointing
|
||||
^^^^^^^^^^^^^
|
||||
|
||||
|
|
10
setup.py
10
setup.py
|
@ -76,6 +76,10 @@ def main():
|
|||
'tensorflow == 1.15.4'
|
||||
]
|
||||
|
||||
horovod_pypi_dep = [
|
||||
'horovod'
|
||||
]
|
||||
|
||||
# Due to pip craziness environment variables are the only consistent way to
|
||||
# get options into this script when doing `pip install`.
|
||||
tc_decoder_artifacts_root = os.environ.get('DECODER_ARTIFACTS_ROOT', '')
|
||||
|
@ -94,6 +98,12 @@ def main():
|
|||
else:
|
||||
install_requires = install_requires + tensorflow_pypi_dep
|
||||
|
||||
if os.environ.get('DS_NOHOROVOD', ''):
|
||||
install_requires = install_requires
|
||||
else:
|
||||
install_requires = install_requires + horovod_pypi_dep
|
||||
|
||||
|
||||
setup(
|
||||
name='deepspeech_training',
|
||||
version=version,
|
||||
|
|
|
@ -424,7 +424,8 @@ def train():
|
|||
process_ahead=len(Config.available_devices) * FLAGS.train_batch_size * 2,
|
||||
reverse=FLAGS.reverse_train,
|
||||
limit=FLAGS.limit_train,
|
||||
buffering=FLAGS.read_buffer)
|
||||
buffering=FLAGS.read_buffer,
|
||||
split_dataset=False)
|
||||
|
||||
iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
|
||||
tfv1.data.get_output_shapes(train_set),
|
||||
|
@ -442,7 +443,8 @@ def train():
|
|||
process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2,
|
||||
reverse=FLAGS.reverse_dev,
|
||||
limit=FLAGS.limit_dev,
|
||||
buffering=FLAGS.read_buffer) for source in dev_sources]
|
||||
buffering=FLAGS.read_buffer,
|
||||
split_dataset=False) for source in dev_sources]
|
||||
dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]
|
||||
|
||||
if FLAGS.metrics_files:
|
||||
|
@ -454,7 +456,8 @@ def train():
|
|||
process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2,
|
||||
reverse=FLAGS.reverse_dev,
|
||||
limit=FLAGS.limit_dev,
|
||||
buffering=FLAGS.read_buffer) for source in metrics_sources]
|
||||
buffering=FLAGS.read_buffer,
|
||||
split_dataset=False) for source in metrics_sources]
|
||||
metrics_init_ops = [iterator.make_initializer(metrics_set) for metrics_set in metrics_sets]
|
||||
|
||||
# Dropout
|
||||
|
@ -677,6 +680,303 @@ def train():
|
|||
log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time))
|
||||
log_debug('Session closed.')
|
||||
|
||||
def train_with_horovod():
|
||||
|
||||
import horovod.tensorflow as hvd
|
||||
|
||||
exception_box = ExceptionBox()
|
||||
|
||||
# Create training and validation datasets
|
||||
train_set = create_dataset(FLAGS.train_files.split(','),
|
||||
batch_size=FLAGS.train_batch_size,
|
||||
epochs=FLAGS.epochs,
|
||||
augmentations=Config.augmentations,
|
||||
cache_path=FLAGS.feature_cache,
|
||||
train_phase=True,
|
||||
exception_box=exception_box,
|
||||
process_ahead=Config.num_devices * FLAGS.train_batch_size * 2,
|
||||
reverse=FLAGS.reverse_train,
|
||||
limit=FLAGS.limit_train,
|
||||
buffering=FLAGS.read_buffer,
|
||||
split_dataset=True)
|
||||
|
||||
iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
|
||||
tfv1.data.get_output_shapes(train_set),
|
||||
output_classes=tfv1.data.get_output_classes(train_set))
|
||||
|
||||
# Make initialization ops for switching between the two sets
|
||||
train_init_op = iterator.make_initializer(train_set)
|
||||
|
||||
if FLAGS.dev_files:
|
||||
dev_sources = FLAGS.dev_files.split(',')
|
||||
dev_sets = [create_dataset([source],
|
||||
batch_size=FLAGS.dev_batch_size,
|
||||
train_phase=False,
|
||||
exception_box=exception_box,
|
||||
process_ahead=Config.num_devices * FLAGS.dev_batch_size * 2,
|
||||
reverse=FLAGS.reverse_dev,
|
||||
limit=FLAGS.limit_dev,
|
||||
buffering=FLAGS.read_buffer,
|
||||
split_dataset=True) for source in dev_sources]
|
||||
dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]
|
||||
|
||||
if FLAGS.metrics_files:
|
||||
metrics_sources = FLAGS.metrics_files.split(',')
|
||||
metrics_sets = [create_dataset([source],
|
||||
batch_size=FLAGS.dev_batch_size,
|
||||
train_phase=False,
|
||||
exception_box=exception_box,
|
||||
process_ahead=Config.num_devices * FLAGS.dev_batch_size * 2,
|
||||
reverse=FLAGS.reverse_dev,
|
||||
limit=FLAGS.limit_dev,
|
||||
buffering=FLAGS.read_buffer,
|
||||
split_dataset=True) for source in metrics_sources]
|
||||
metrics_init_ops = [iterator.make_initializer(metrics_set) for metrics_set in metrics_sets]
|
||||
|
||||
# Dropout
|
||||
dropout_rates = [tfv1.placeholder(tf.float32, name='dropout_{}'.format(i)) for i in range(6)]
|
||||
dropout_feed_dict = {
|
||||
dropout_rates[0]: FLAGS.dropout_rate,
|
||||
dropout_rates[1]: FLAGS.dropout_rate2,
|
||||
dropout_rates[2]: FLAGS.dropout_rate3,
|
||||
dropout_rates[3]: FLAGS.dropout_rate4,
|
||||
dropout_rates[4]: FLAGS.dropout_rate5,
|
||||
dropout_rates[5]: FLAGS.dropout_rate6,
|
||||
}
|
||||
no_dropout_feed_dict = {
|
||||
rate: 0. for rate in dropout_rates
|
||||
}
|
||||
|
||||
# Building the graph
|
||||
learning_rate_var = tfv1.get_variable('learning_rate', initializer=FLAGS.learning_rate, trainable=False)
|
||||
reduce_learning_rate_op = learning_rate_var.assign(tf.multiply(learning_rate_var, FLAGS.plateau_reduction))
|
||||
|
||||
# Effective batch size in synchronous distributed training is scaled by the number of workers. An increase in learning rate compensates for the increased batch size.
|
||||
optimizer = create_optimizer(learning_rate_var * hvd.size())
|
||||
optimizer = hvd.DistributedOptimizer(optimizer)
|
||||
|
||||
# Enable mixed precision training
|
||||
if FLAGS.automatic_mixed_precision:
|
||||
log_info('Enabling automatic mixed precision training.')
|
||||
optimizer = tfv1.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)
|
||||
|
||||
loss, non_finite_files = calculate_mean_edit_distance_and_loss(iterator, dropout_rates, reuse=False)
|
||||
gradients = optimizer.compute_gradients(loss)
|
||||
|
||||
tfv1.summary.scalar(name='step_loss', tensor=loss, collections=['step_summaries'])
|
||||
log_grads_and_vars(gradients)
|
||||
|
||||
# global_step is automagically incremented by the optimizer
|
||||
global_step = tfv1.train.get_or_create_global_step()
|
||||
apply_gradient_op = optimizer.apply_gradients(gradients, global_step=global_step)
|
||||
|
||||
# Summaries
|
||||
step_summaries_op = tfv1.summary.merge_all('step_summaries')
|
||||
step_summary_writers = {
|
||||
'train': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'train'), max_queue=120),
|
||||
'dev': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'dev'), max_queue=120),
|
||||
'metrics': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'metrics'), max_queue=120),
|
||||
}
|
||||
|
||||
human_readable_set_names = {
|
||||
'train': 'Training',
|
||||
'dev': 'Validation',
|
||||
'metrics': 'Metrics',
|
||||
}
|
||||
|
||||
# Checkpointing
|
||||
if Config.is_master_process:
|
||||
checkpoint_saver = tfv1.train.Saver(max_to_keep=FLAGS.max_to_keep)
|
||||
checkpoint_path = os.path.join(FLAGS.save_checkpoint_dir, 'train')
|
||||
|
||||
best_dev_saver = tfv1.train.Saver(max_to_keep=1)
|
||||
best_dev_path = os.path.join(FLAGS.save_checkpoint_dir, 'best_dev')
|
||||
|
||||
# Save flags next to checkpoints
|
||||
if not is_remote_path(FLAGS.save_checkpoint_dir):
|
||||
os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True)
|
||||
flags_file = os.path.join(FLAGS.save_checkpoint_dir, 'flags.txt')
|
||||
with open_remote(flags_file, 'w') as fout:
|
||||
fout.write(FLAGS.flags_into_string())
|
||||
|
||||
bcast = hvd.broadcast_global_variables(0)
|
||||
|
||||
with tfv1.Session(config=Config.session_config) as session:
|
||||
log_debug('Session opened.')
|
||||
|
||||
# Prevent further graph changes
|
||||
tfv1.get_default_graph().finalize()
|
||||
|
||||
# Load checkpoint or initialize variables
|
||||
load_or_init_graph_for_training(session)
|
||||
bcast.run()
|
||||
|
||||
def run_set(set_name, epoch, init_op, dataset=None):
|
||||
is_train = set_name == 'train'
|
||||
train_op = apply_gradient_op if is_train else []
|
||||
feed_dict = dropout_feed_dict if is_train else no_dropout_feed_dict
|
||||
|
||||
total_loss = 0.0
|
||||
step_count = 0
|
||||
|
||||
step_summary_writer = step_summary_writers.get(set_name)
|
||||
checkpoint_time = time.time()
|
||||
|
||||
if is_train and FLAGS.cache_for_epochs > 0 and FLAGS.feature_cache:
|
||||
feature_cache_index = FLAGS.feature_cache + '.index'
|
||||
if epoch % FLAGS.cache_for_epochs == 0 and os.path.isfile(feature_cache_index):
|
||||
log_info('Invalidating feature cache')
|
||||
remove_remote(feature_cache_index) # this will let TF also overwrite the related cache data files
|
||||
|
||||
# Setup progress bar
|
||||
class LossWidget(progressbar.widgets.FormatLabel):
|
||||
def __init__(self):
|
||||
progressbar.widgets.FormatLabel.__init__(self, format='Loss: %(mean_loss)f')
|
||||
|
||||
def __call__(self, progress, data, **kwargs):
|
||||
data['mean_loss'] = total_loss / step_count if step_count else 0.0
|
||||
return progressbar.widgets.FormatLabel.__call__(self, progress, data, **kwargs)
|
||||
|
||||
if Config.is_master_process:
|
||||
# TODO endl seems not to work with horovod
|
||||
prefix = 'Epoch {} | {:>10}'.format(epoch, human_readable_set_names[set_name])
|
||||
widgets = [' | ', progressbar.widgets.Timer(),
|
||||
' | Steps: ', progressbar.widgets.Counter(),
|
||||
' | ', LossWidget()]
|
||||
suffix = ' | Dataset: {}'.format(dataset) if dataset else None
|
||||
pbar = create_progressbar(prefix=prefix, widgets=widgets, suffix=suffix).start()
|
||||
|
||||
# Initialize iterator to the appropriate dataset
|
||||
session.run(init_op)
|
||||
|
||||
# Batch loop
|
||||
while True:
|
||||
try:
|
||||
_, current_step, batch_loss, problem_files, step_summary = \
|
||||
session.run([train_op, global_step, loss, non_finite_files, step_summaries_op],
|
||||
feed_dict=feed_dict)
|
||||
exception_box.raise_if_set()
|
||||
except tf.errors.OutOfRangeError:
|
||||
exception_box.raise_if_set()
|
||||
break
|
||||
|
||||
if problem_files.size > 0:
|
||||
problem_files = [f.decode('utf8') for f in problem_files[..., 0]]
|
||||
log_error('The following files caused an infinite (or NaN) '
|
||||
'loss: {}'.format(','.join(problem_files)))
|
||||
|
||||
total_loss += batch_loss
|
||||
step_count += 1
|
||||
|
||||
if Config.is_master_process:
|
||||
pbar.update(step_count)
|
||||
step_summary_writer.add_summary(step_summary, current_step)
|
||||
|
||||
if is_train and FLAGS.checkpoint_secs > 0 and time.time() - checkpoint_time > FLAGS.checkpoint_secs:
|
||||
checkpoint_saver.save(session, checkpoint_path, global_step=current_step)
|
||||
checkpoint_time = time.time()
|
||||
|
||||
pbar.finish()
|
||||
mean_loss = total_loss / step_count if step_count > 0 else 0.0
|
||||
return mean_loss, step_count
|
||||
|
||||
log_info('STARTING Optimization')
|
||||
train_start_time = datetime.utcnow()
|
||||
best_dev_loss = float('inf')
|
||||
dev_losses = []
|
||||
epochs_without_improvement = 0
|
||||
try:
|
||||
for epoch in range(FLAGS.epochs):
|
||||
# Training
|
||||
if Config.is_master_process:
|
||||
log_progress('Training epoch %d...' % epoch)
|
||||
train_loss, _ = run_set('train', epoch, train_init_op)
|
||||
if Config.is_master_process:
|
||||
log_progress('Finished training epoch %d - loss: %f' % (epoch, train_loss))
|
||||
checkpoint_saver.save(session, checkpoint_path, global_step=global_step)
|
||||
|
||||
if FLAGS.dev_files:
|
||||
# Validation
|
||||
dev_loss = 0.0
|
||||
total_steps = 0
|
||||
for source, init_op in zip(dev_sources, dev_init_ops):
|
||||
if Config.is_master_process:
|
||||
log_progress('Validating epoch %d on %s...' % (epoch, source))
|
||||
set_loss, steps = run_set('dev', epoch, init_op, dataset=source)
|
||||
dev_loss += set_loss * steps
|
||||
total_steps += steps
|
||||
if Config.is_master_process:
|
||||
log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, source, set_loss))
|
||||
|
||||
dev_loss = dev_loss / total_steps
|
||||
dev_losses.append(dev_loss)
|
||||
|
||||
# Count epochs without an improvement for early stopping and reduction of learning rate on a plateau
|
||||
# the improvement has to be greater than FLAGS.es_min_delta
|
||||
if dev_loss > best_dev_loss - FLAGS.es_min_delta:
|
||||
epochs_without_improvement += 1
|
||||
else:
|
||||
epochs_without_improvement = 0
|
||||
|
||||
if Config.is_master_process:
|
||||
# Save new best model
|
||||
if dev_loss < best_dev_loss:
|
||||
best_dev_loss = dev_loss
|
||||
save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step,
|
||||
latest_filename='best_dev_checkpoint')
|
||||
log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path))
|
||||
|
||||
# Early stopping
|
||||
if FLAGS.early_stop and epochs_without_improvement == FLAGS.es_epochs:
|
||||
if Config.is_master_process:
|
||||
log_info('Early stop triggered as the loss did not improve the last {} epochs'.format(
|
||||
epochs_without_improvement))
|
||||
break
|
||||
|
||||
# Reduce learning rate on plateau
|
||||
# If the learning rate was reduced and there is still no improvement
|
||||
# wait FLAGS.plateau_epochs before the learning rate is reduced again
|
||||
if (
|
||||
FLAGS.reduce_lr_on_plateau
|
||||
and epochs_without_improvement > 0
|
||||
and epochs_without_improvement % FLAGS.plateau_epochs == 0
|
||||
):
|
||||
# Reload checkpoint that we use the best_dev weights again
|
||||
reload_best_checkpoint(session)
|
||||
|
||||
# Reduce learning rate
|
||||
session.run(reduce_learning_rate_op)
|
||||
current_learning_rate = learning_rate_var.eval()
|
||||
if Config.is_master_process:
|
||||
log_info('Encountered a plateau, reducing learning rate to {}'.format(
|
||||
current_learning_rate))
|
||||
|
||||
# Overwrite best checkpoint with new learning rate value
|
||||
save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step,
|
||||
latest_filename='best_dev_checkpoint')
|
||||
log_info("Saved best validating model with reduced learning rate to: %s" % (save_path))
|
||||
|
||||
if FLAGS.metrics_files:
|
||||
# Read only metrics, not affecting best validation loss tracking
|
||||
for source, init_op in zip(metrics_sources, metrics_init_ops):
|
||||
if Config.is_master_process:
|
||||
log_progress('Metrics for epoch %d on %s...' % (epoch, source))
|
||||
set_loss, _ = run_set('metrics', epoch, init_op, dataset=source)
|
||||
if Config.is_master_process:
|
||||
log_progress('Metrics for epoch %d on %s - loss: %f' % (epoch, source, set_loss))
|
||||
|
||||
if Config.is_master_process:
|
||||
print('-' * 80)
|
||||
|
||||
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
if Config.is_master_process:
|
||||
log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time))
|
||||
if Config.is_master_process:
|
||||
log_debug('Session closed.')
|
||||
|
||||
|
||||
|
||||
def test():
|
||||
samples = evaluate(FLAGS.test_files.split(','), create_model)
|
||||
|
@ -951,30 +1251,35 @@ def main(_):
|
|||
if FLAGS.train_files:
|
||||
tfv1.reset_default_graph()
|
||||
tfv1.set_random_seed(FLAGS.random_seed)
|
||||
train()
|
||||
|
||||
if FLAGS.test_files:
|
||||
tfv1.reset_default_graph()
|
||||
test()
|
||||
if FLAGS.horovod:
|
||||
train_with_horovod()
|
||||
else:
|
||||
train()
|
||||
|
||||
if FLAGS.export_dir and not FLAGS.export_zip:
|
||||
tfv1.reset_default_graph()
|
||||
export()
|
||||
if Config.is_master_process:
|
||||
if FLAGS.test_files:
|
||||
tfv1.reset_default_graph()
|
||||
test()
|
||||
|
||||
if FLAGS.export_zip:
|
||||
tfv1.reset_default_graph()
|
||||
FLAGS.export_tflite = True
|
||||
if FLAGS.export_dir and not FLAGS.export_zip:
|
||||
tfv1.reset_default_graph()
|
||||
export()
|
||||
|
||||
if listdir_remote(FLAGS.export_dir):
|
||||
log_error('Directory {} is not empty, please fix this.'.format(FLAGS.export_dir))
|
||||
sys.exit(1)
|
||||
if FLAGS.export_zip:
|
||||
tfv1.reset_default_graph()
|
||||
FLAGS.export_tflite = True
|
||||
|
||||
export()
|
||||
package_zip()
|
||||
if listdir_remote(FLAGS.export_dir):
|
||||
log_error('Directory {} is not empty, please fix this.'.format(FLAGS.export_dir))
|
||||
sys.exit(1)
|
||||
|
||||
if FLAGS.one_shot_infer:
|
||||
tfv1.reset_default_graph()
|
||||
do_single_file_inference(FLAGS.one_shot_infer)
|
||||
export()
|
||||
package_zip()
|
||||
|
||||
if FLAGS.one_shot_infer:
|
||||
tfv1.reset_default_graph()
|
||||
do_single_file_inference(FLAGS.one_shot_infer)
|
||||
|
||||
|
||||
def run_script():
|
||||
|
|
|
@ -79,12 +79,33 @@ def initialize_globals():
|
|||
# CPU device
|
||||
c.cpu_device = '/cpu:0'
|
||||
|
||||
# Available GPU devices
|
||||
c.available_devices = get_available_gpus(c.session_config)
|
||||
if FLAGS.horovod:
|
||||
try:
|
||||
import horovod.tensorflow as hvd
|
||||
except ImportError as e:
|
||||
print(
|
||||
"Error importing Horovod. Did you installed DeepSpeech with -DNOHOROVOD? "
|
||||
"If you do not want to use horovod, use 'from deepspeech_training import train'")
|
||||
raise e
|
||||
|
||||
# If there is no GPU available, we fall back to CPU based operation
|
||||
if not c.available_devices:
|
||||
c.available_devices = [c.cpu_device]
|
||||
hvd.init()
|
||||
|
||||
# Pin GPU to be used to process local rank (one GPU per process)
|
||||
c.session_config.gpu_options.visible_device_list = str(hvd.local_rank())
|
||||
c.num_devices = hvd.size()
|
||||
c.is_master_process = True if hvd.rank() == 0 else False
|
||||
else:
|
||||
# # Available GPU devices
|
||||
c.available_devices = get_available_gpus(c.session_config)
|
||||
|
||||
# If there is no GPU available, we fall back to CPU based operation
|
||||
if not c.available_devices:
|
||||
c.available_devices = [c.cpu_device]
|
||||
|
||||
c.num_devices = len(c.available_devices)
|
||||
|
||||
# If there are no horovod processes the only one should handled like horovod master
|
||||
c.is_master_process = True
|
||||
|
||||
if FLAGS.bytes_output_mode:
|
||||
c.alphabet = UTF8Alphabet()
|
||||
|
|
|
@ -94,7 +94,8 @@ def create_dataset(sources,
|
|||
limit=0,
|
||||
exception_box=None,
|
||||
process_ahead=None,
|
||||
buffering=1 * MEGABYTE):
|
||||
buffering=1 * MEGABYTE,
|
||||
split_dataset=False):
|
||||
epoch_counter = Counter() # survives restarts of the dataset and its generator
|
||||
|
||||
def generate_values():
|
||||
|
@ -135,17 +136,25 @@ def create_dataset(sources,
|
|||
|
||||
process_fn = partial(entry_to_features, train_phase=train_phase, augmentations=augmentations)
|
||||
|
||||
dataset = (tf.data.Dataset.from_generator(remember_exception(generate_values, exception_box),
|
||||
output_types=(tf.string, tf.float32, tf.int32,
|
||||
(tf.int64, tf.int32, tf.int64), tf.float64))
|
||||
.map(process_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE))
|
||||
dataset = tf.data.Dataset.from_generator(remember_exception(generate_values, exception_box),
|
||||
output_types=(tf.string, tf.float32, tf.int32,
|
||||
(tf.int64, tf.int32, tf.int64), tf.float64))
|
||||
if split_dataset:
|
||||
# Using horovod Iterator.get_next() is not aware of different devices.
|
||||
# A.shard(n, i) will contain all elements of A whose index mod n = i.
|
||||
import horovod.tensorflow as hvd
|
||||
dataset = dataset.shard(hvd.size(), hvd.rank())
|
||||
dataset = dataset.map(process_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
|
||||
if cache_path:
|
||||
dataset = dataset.cache(cache_path)
|
||||
dataset = (dataset.window(batch_size, drop_remainder=train_phase).flat_map(batch_fn)
|
||||
.prefetch(len(Config.available_devices)))
|
||||
dataset = (dataset.window(batch_size, drop_remainder=train_phase).flat_map(batch_fn))
|
||||
if split_dataset:
|
||||
#TODO is there a way to get a proper value?
|
||||
dataset = dataset.prefetch(2)
|
||||
else:
|
||||
dataset = dataset.prefetch(Config.num_devices)
|
||||
return dataset
|
||||
|
||||
|
||||
def split_audio_file(audio_path,
|
||||
audio_format=DEFAULT_FORMAT,
|
||||
batch_size=1,
|
||||
|
@ -178,5 +187,5 @@ def split_audio_file(audio_path,
|
|||
ods = create_batch_set(outlier_batch_size,
|
||||
lambda start, end, f, fl: end - start > int(outlier_duration_ms))
|
||||
dataset = nds.concatenate(ods)
|
||||
dataset = dataset.prefetch(len(Config.available_devices))
|
||||
dataset = dataset.prefetch(Config.num_devices)
|
||||
return dataset
|
||||
|
|
|
@ -69,6 +69,8 @@ def create_flags():
|
|||
f.DEFINE_boolean('train_cudnn', False, 'use CuDNN RNN backend for training on GPU. Note that checkpoints created with this flag can only be used with CuDNN RNN, i.e. fine tuning on a CPU device will not work')
|
||||
f.DEFINE_boolean('automatic_mixed_precision', False, 'whether to allow automatic mixed precision training. USE OF THIS FLAG IS UNSUPPORTED. Checkpoints created with automatic mixed precision training will not be usable without mixed precision.')
|
||||
|
||||
f.DEFINE_boolean('horovod', False, 'use horovod for training on multiple gpus')
|
||||
|
||||
# Sample limits
|
||||
|
||||
f.DEFINE_integer('limit_train', 0, 'maximum number of elements to use from train set - 0 means no limit')
|
||||
|
|
Загрузка…
Ссылка в новой задаче