diff --git a/DeepSpeech.py b/DeepSpeech.py
index c2a22621..cf944aee 100755
--- a/DeepSpeech.py
+++ b/DeepSpeech.py
@@ -11,6 +11,7 @@ os.environ['TF_CPP_MIN_LOG_LEVEL'] = sys.argv[log_level_index] if log_level_inde
 import datetime
 import pickle
 import shutil
+import six
 import subprocess
 import tensorflow as tf
 import time
@@ -18,7 +19,6 @@ import traceback
 import inspect
 
 from six.moves import zip, range, filter, urllib, BaseHTTPServer
-from tensorflow.contrib.session_bundle import exporter
 from tensorflow.python.tools import freeze_graph
 from threading import Thread, Lock
 from util.audio import audiofile_to_input_vector
@@ -432,18 +432,14 @@ def BiRNN(batch_x, seq_length, dropout):
     # Now we create the forward and backward LSTM units.
     # Both of which have inputs of length `n_cell_dim` and bias `1.0` for the forget gate of the LSTM.
 
-    # Forward direction cell: (if else required for TF 1.0 and 1.1 compat)
-    lstm_fw_cell = tf.contrib.rnn.BasicLSTMCell(n_cell_dim, forget_bias=1.0, state_is_tuple=True) \
-                   if 'reuse' not in inspect.getargspec(tf.contrib.rnn.BasicLSTMCell.__init__).args else \
-                   tf.contrib.rnn.BasicLSTMCell(n_cell_dim, forget_bias=1.0, state_is_tuple=True, reuse=tf.get_variable_scope().reuse)
+    # Forward direction cell:
+    lstm_fw_cell = tf.contrib.rnn.BasicLSTMCell(n_cell_dim, forget_bias=1.0, state_is_tuple=True, reuse=tf.get_variable_scope().reuse)
     lstm_fw_cell = tf.contrib.rnn.DropoutWrapper(lstm_fw_cell,
                                                  input_keep_prob=1.0 - dropout[3],
                                                  output_keep_prob=1.0 - dropout[3],
                                                  seed=FLAGS.random_seed)
-    # Backward direction cell: (if else required for TF 1.0 and 1.1 compat)
-    lstm_bw_cell = tf.contrib.rnn.BasicLSTMCell(n_cell_dim, forget_bias=1.0, state_is_tuple=True) \
-                   if 'reuse' not in inspect.getargspec(tf.contrib.rnn.BasicLSTMCell.__init__).args else \
-                   tf.contrib.rnn.BasicLSTMCell(n_cell_dim, forget_bias=1.0, state_is_tuple=True, reuse=tf.get_variable_scope().reuse)
+    # Backward direction cell:
+    lstm_bw_cell = tf.contrib.rnn.BasicLSTMCell(n_cell_dim, forget_bias=1.0, state_is_tuple=True, reuse=tf.get_variable_scope().reuse)
     lstm_bw_cell = tf.contrib.rnn.DropoutWrapper(lstm_bw_cell,
                                                  input_keep_prob=1.0 - dropout[4],
                                                  output_keep_prob=1.0 - dropout[4],
@@ -1685,7 +1681,7 @@ def train(server=None):
                       ' or removing the contents of {0}.'.format(FLAGS.checkpoint_dir))
             sys.exit(1)
 
-def create_inference_graph(batch_size=None, output_is_logits=False, use_new_decoder=False):
+def create_inference_graph(batch_size=None, use_new_decoder=False):
     # Input tensor will be of shape [batch_size, n_steps, n_input + 2*n_input*n_context]
     input_tensor = tf.placeholder(tf.float32, [batch_size, None, n_input + 2*n_input*n_context], name='input_node')
     seq_length = tf.placeholder(tf.int32, [batch_size], name='input_lengths')
@@ -1693,9 +1689,6 @@ def create_inference_graph(batch_size=None, output_is_logits=False, use_new_deco
     # Calculate the logits of the batch using BiRNN
     logits = BiRNN(input_tensor, tf.to_int64(seq_length) if FLAGS.use_seq_length else None, no_dropout)
 
-    if output_is_logits:
-        return logits
-
     # Beam search decode the batch
     decoder = decode_with_lm if use_new_decoder else tf.nn.ctc_beam_search_decoder
 
@@ -1730,48 +1723,32 @@ def export():
 
         # Create a saver and exporter using variables from the above newly created graph
         saver = tf.train.Saver(tf.global_variables())
-        model_exporter = exporter.Exporter(saver)
 
         # Restore variables from training checkpoint
-        # TODO: This restores the most recent checkpoint, but if we use validation to counterract
-        # over-fitting, we may want to restore an earlier checkpoint.
         checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
         checkpoint_path = checkpoint.model_checkpoint_path
-        saver.restore(session, checkpoint_path)
-        log_info('Restored checkpoint at training epoch %d' % (int(checkpoint_path.split('-')[-1]) + 1))
-
-        # Initialise the model exporter and export the model
-        model_exporter.init(session.graph.as_graph_def(),
-                            named_graph_signatures = {
-                                'inputs': exporter.generic_signature(inputs),
-                                'outputs': exporter.generic_signature(outputs)
-                            })
         if FLAGS.remove_export:
-            actual_export_dir = os.path.join(FLAGS.export_dir, '%08d' % FLAGS.export_version)
-            if os.path.isdir(actual_export_dir):
+            if os.path.isdir(FLAGS.export_dir):
                 log_info('Removing old export')
-                shutil.rmtree(actual_export_dir)
+                shutil.rmtree(FLAGS.export_dir)
+
         try:
-            # Export serving model
-            model_exporter.export(FLAGS.export_dir, tf.constant(FLAGS.export_version), session)
+            output_graph_path = os.path.join(FLAGS.export_dir, 'output_graph.pb')
 
-            # Export graph
-            input_graph_name = 'input_graph.pb'
-            tf.train.write_graph(session.graph, FLAGS.export_dir, input_graph_name, as_text=False)
+            if not os.path.isdir(FLAGS.export_dir):
+                os.makedirs(FLAGS.export_dir)
 
             # Freeze graph
-            input_graph_path = os.path.join(FLAGS.export_dir, input_graph_name)
-            input_saver_def_path = ''
-            input_binary = True
-            output_node_names = 'output_node'
-            restore_op_name = 'save/restore_all'
-            filename_tensor_name = 'save/Const:0'
-            output_graph_path = os.path.join(FLAGS.export_dir, 'output_graph.pb')
-            clear_devices = False
-            freeze_graph.freeze_graph(input_graph_path, input_saver_def_path,
-                                      input_binary, checkpoint_path, output_node_names,
-                                      restore_op_name, filename_tensor_name,
-                                      output_graph_path, clear_devices, '')
+            freeze_graph.freeze_graph_with_def_protos(
+                input_graph_def=session.graph_def,
+                input_saver_def=saver.as_saver_def(),
+                input_checkpoint=checkpoint_path,
+                output_node_names=','.join(node.op.name for node in six.itervalues(outputs)),
+                restore_op_name=None,
+                filename_tensor_name=None,
+                output_graph=output_graph_path,
+                clear_devices=False,
+                initializer_nodes='')
 
             log_info('Models exported at %s' % (FLAGS.export_dir))
         except RuntimeError as e:
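
For reference, a minimal sketch (not part of this patch) of how the frozen output_graph.pb written by the new export path might be loaded for inference under TensorFlow 1.x. The node names input_node, input_lengths and output_node come from create_inference_graph() above; the file path and the feature dimensions (assuming the default n_input=26 and n_context=9, giving 26 + 2*26*9 = 494 features per step) are illustrative assumptions only:

    import numpy as np
    import tensorflow as tf

    # Parse the frozen GraphDef produced by freeze_graph_with_def_protos.
    graph_def = tf.GraphDef()
    with tf.gfile.FastGFile('export/output_graph.pb', 'rb') as fin:  # assumed path
        graph_def.ParseFromString(fin.read())

    graph = tf.Graph()
    with graph.as_default():
        tf.import_graph_def(graph_def, name='')

    with tf.Session(graph=graph) as session:
        # Shapes are illustrative: one utterance of 100 time steps with
        # n_input + 2*n_input*n_context = 494 features per step.
        features = np.zeros((1, 100, 494), dtype=np.float32)
        decoded = session.run('output_node:0', feed_dict={
            'input_node:0': features,
            'input_lengths:0': [100],
        })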