diff --git a/DeepSpeech.py b/DeepSpeech.py
index 870250e8..f9c17fc3 100755
--- a/DeepSpeech.py
+++ b/DeepSpeech.py
@@ -446,10 +446,12 @@ def train():
 
     # Building the graph
     optimizer = create_optimizer()
-
+    # Enable mixed precision training
     if FLAGS.automatic_mixed_precision:
-        optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)
+        log_info('Enabling automatic mixed precision training.')
+        optimizer = tfv1.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)
+
     gradients, loss, non_finite_files = get_tower_results(iterator, optimizer, dropout_rates)
 
     # Average tower gradients across GPUs
diff --git a/util/flags.py b/util/flags.py
index 2e585a7c..e5b3afdd 100644
--- a/util/flags.py
+++ b/util/flags.py
@@ -80,8 +80,8 @@ def create_flags():
     f.DEFINE_boolean('use_cudnn_rnn', False, 'use CuDNN RNN backend for training on GPU. Note that checkpoints created with this flag can only be used with CuDNN RNN, i.e. fine tuning on a CPU device will not work')
     f.DEFINE_string('cudnn_checkpoint', '', 'path to a checkpoint created using --use_cudnn_rnn. Specifying this flag allows one to convert a CuDNN RNN checkpoint to a checkpoint capable of running on a CPU graph.')
 
-    f.DEFINE_boolean('automatic_mixed_precision', False, 'whether to allow automatic mixed precision training')
-
+    f.DEFINE_boolean('automatic_mixed_precision', False, 'whether to allow automatic mixed precision training. USE OF THIS FLAG IS UNSUPPORTED. Checkpoints created with automatic mixed precision training will not be usable without mixed precision.')
+
     # Sample limits
     f.DEFINE_integer('limit_train', 0, 'maximum number of elements to use from train set - 0 means no limit')
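For context beyond the patch itself: `enable_mixed_precision_graph_rewrite` wraps the optimizer in a loss-scaling optimizer and enables a graph rewrite that inserts float16 casts where TensorFlow considers them numerically safe; the rewrite only takes effect on supported NVIDIA GPUs. Below is a minimal standalone sketch of the call the diff switches to, assuming TensorFlow 1.14+ with the `tf.compat.v1` alias (`tfv1`) that DeepSpeech uses; everything except the TF API names is illustrative, not from the patch.

```python
import numpy as np
import tensorflow.compat.v1 as tfv1

tfv1.disable_eager_execution()

# Toy regression graph; stands in for DeepSpeech's acoustic model.
x = tfv1.placeholder(tfv1.float32, [None, 4])
y = tfv1.placeholder(tfv1.float32, [None, 1])
loss = tfv1.losses.mean_squared_error(y, tfv1.layers.dense(x, 1))

optimizer = tfv1.train.AdamOptimizer(learning_rate=1e-3)
# Wraps the optimizer with automatic loss scaling and enables the
# float16 graph rewrite (a no-op without a supported NVIDIA GPU).
optimizer = tfv1.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)
train_op = optimizer.minimize(loss)

with tfv1.Session() as session:
    session.run(tfv1.global_variables_initializer())
    session.run(train_op, feed_dict={x: np.ones((8, 4), np.float32),
                                     y: np.ones((8, 1), np.float32)})
```

The wrapped optimizer keeps its loss-scale state in extra variables that end up in saved checkpoints, which is presumably why the new help text warns that checkpoints created with this flag are not usable without mixed precision.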