зеркало из https://github.com/mozilla/DeepSpeech.git
Merge pull request #1229 from mozilla/export-cleanup
Clean up export code and remove some TF 1.0 compat code (Fixes #1228)
This commit is contained in:
Коммит
32c06acdf8
|
@ -11,6 +11,7 @@ os.environ['TF_CPP_MIN_LOG_LEVEL'] = sys.argv[log_level_index] if log_level_inde
|
|||
import datetime
|
||||
import pickle
|
||||
import shutil
|
||||
import six
|
||||
import subprocess
|
||||
import tensorflow as tf
|
||||
import time
|
||||
|
@ -18,7 +19,6 @@ import traceback
|
|||
import inspect
|
||||
|
||||
from six.moves import zip, range, filter, urllib, BaseHTTPServer
|
||||
from tensorflow.contrib.session_bundle import exporter
|
||||
from tensorflow.python.tools import freeze_graph
|
||||
from threading import Thread, Lock
|
||||
from util.audio import audiofile_to_input_vector
|
||||
|
@ -432,18 +432,14 @@ def BiRNN(batch_x, seq_length, dropout):
|
|||
# Now we create the forward and backward LSTM units.
|
||||
# Both of which have inputs of length `n_cell_dim` and bias `1.0` for the forget gate of the LSTM.
|
||||
|
||||
# Forward direction cell: (if else required for TF 1.0 and 1.1 compat)
|
||||
lstm_fw_cell = tf.contrib.rnn.BasicLSTMCell(n_cell_dim, forget_bias=1.0, state_is_tuple=True) \
|
||||
if 'reuse' not in inspect.getargspec(tf.contrib.rnn.BasicLSTMCell.__init__).args else \
|
||||
tf.contrib.rnn.BasicLSTMCell(n_cell_dim, forget_bias=1.0, state_is_tuple=True, reuse=tf.get_variable_scope().reuse)
|
||||
# Forward direction cell:
|
||||
lstm_fw_cell = tf.contrib.rnn.BasicLSTMCell(n_cell_dim, forget_bias=1.0, state_is_tuple=True, reuse=tf.get_variable_scope().reuse)
|
||||
lstm_fw_cell = tf.contrib.rnn.DropoutWrapper(lstm_fw_cell,
|
||||
input_keep_prob=1.0 - dropout[3],
|
||||
output_keep_prob=1.0 - dropout[3],
|
||||
seed=FLAGS.random_seed)
|
||||
# Backward direction cell: (if else required for TF 1.0 and 1.1 compat)
|
||||
lstm_bw_cell = tf.contrib.rnn.BasicLSTMCell(n_cell_dim, forget_bias=1.0, state_is_tuple=True) \
|
||||
if 'reuse' not in inspect.getargspec(tf.contrib.rnn.BasicLSTMCell.__init__).args else \
|
||||
tf.contrib.rnn.BasicLSTMCell(n_cell_dim, forget_bias=1.0, state_is_tuple=True, reuse=tf.get_variable_scope().reuse)
|
||||
# Backward direction cell:
|
||||
lstm_bw_cell = tf.contrib.rnn.BasicLSTMCell(n_cell_dim, forget_bias=1.0, state_is_tuple=True, reuse=tf.get_variable_scope().reuse)
|
||||
lstm_bw_cell = tf.contrib.rnn.DropoutWrapper(lstm_bw_cell,
|
||||
input_keep_prob=1.0 - dropout[4],
|
||||
output_keep_prob=1.0 - dropout[4],
|
||||
|
@ -1685,7 +1681,7 @@ def train(server=None):
|
|||
' or removing the contents of {0}.'.format(FLAGS.checkpoint_dir))
|
||||
sys.exit(1)
|
||||
|
||||
def create_inference_graph(batch_size=None, output_is_logits=False, use_new_decoder=False):
|
||||
def create_inference_graph(batch_size=None, use_new_decoder=False):
|
||||
# Input tensor will be of shape [batch_size, n_steps, n_input + 2*n_input*n_context]
|
||||
input_tensor = tf.placeholder(tf.float32, [batch_size, None, n_input + 2*n_input*n_context], name='input_node')
|
||||
seq_length = tf.placeholder(tf.int32, [batch_size], name='input_lengths')
|
||||
|
@ -1693,9 +1689,6 @@ def create_inference_graph(batch_size=None, output_is_logits=False, use_new_deco
|
|||
# Calculate the logits of the batch using BiRNN
|
||||
logits = BiRNN(input_tensor, tf.to_int64(seq_length) if FLAGS.use_seq_length else None, no_dropout)
|
||||
|
||||
if output_is_logits:
|
||||
return logits
|
||||
|
||||
# Beam search decode the batch
|
||||
decoder = decode_with_lm if use_new_decoder else tf.nn.ctc_beam_search_decoder
|
||||
|
||||
|
@ -1730,48 +1723,32 @@ def export():
|
|||
|
||||
# Create a saver and exporter using variables from the above newly created graph
|
||||
saver = tf.train.Saver(tf.global_variables())
|
||||
model_exporter = exporter.Exporter(saver)
|
||||
|
||||
# Restore variables from training checkpoint
|
||||
# TODO: This restores the most recent checkpoint, but if we use validation to counteract
|
||||
# over-fitting, we may want to restore an earlier checkpoint.
|
||||
checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
|
||||
checkpoint_path = checkpoint.model_checkpoint_path
|
||||
saver.restore(session, checkpoint_path)
|
||||
log_info('Restored checkpoint at training epoch %d' % (int(checkpoint_path.split('-')[-1]) + 1))
|
||||
|
||||
# Initialise the model exporter and export the model
|
||||
model_exporter.init(session.graph.as_graph_def(),
|
||||
named_graph_signatures = {
|
||||
'inputs': exporter.generic_signature(inputs),
|
||||
'outputs': exporter.generic_signature(outputs)
|
||||
})
|
||||
if FLAGS.remove_export:
|
||||
actual_export_dir = os.path.join(FLAGS.export_dir, '%08d' % FLAGS.export_version)
|
||||
if os.path.isdir(actual_export_dir):
|
||||
if os.path.isdir(FLAGS.export_dir):
|
||||
log_info('Removing old export')
|
||||
shutil.rmtree(actual_export_dir)
|
||||
shutil.rmtree(FLAGS.export_dir)
|
||||
try:
|
||||
# Export serving model
|
||||
model_exporter.export(FLAGS.export_dir, tf.constant(FLAGS.export_version), session)
|
||||
output_graph_path = os.path.join(FLAGS.export_dir, 'output_graph.pb')
|
||||
|
||||
# Export graph
|
||||
input_graph_name = 'input_graph.pb'
|
||||
tf.train.write_graph(session.graph, FLAGS.export_dir, input_graph_name, as_text=False)
|
||||
if not os.path.isdir(FLAGS.export_dir):
|
||||
os.makedirs(FLAGS.export_dir)
|
||||
|
||||
# Freeze graph
|
||||
input_graph_path = os.path.join(FLAGS.export_dir, input_graph_name)
|
||||
input_saver_def_path = ''
|
||||
input_binary = True
|
||||
output_node_names = 'output_node'
|
||||
restore_op_name = 'save/restore_all'
|
||||
filename_tensor_name = 'save/Const:0'
|
||||
output_graph_path = os.path.join(FLAGS.export_dir, 'output_graph.pb')
|
||||
clear_devices = False
|
||||
freeze_graph.freeze_graph(input_graph_path, input_saver_def_path,
|
||||
input_binary, checkpoint_path, output_node_names,
|
||||
restore_op_name, filename_tensor_name,
|
||||
output_graph_path, clear_devices, '')
|
||||
freeze_graph.freeze_graph_with_def_protos(
|
||||
input_graph_def=session.graph_def,
|
||||
input_saver_def=saver.as_saver_def(),
|
||||
input_checkpoint=checkpoint_path,
|
||||
output_node_names=','.join(node.op.name for node in six.itervalues(outputs)),
|
||||
restore_op_name=None,
|
||||
filename_tensor_name=None,
|
||||
output_graph=output_graph_path,
|
||||
clear_devices=False,
|
||||
initializer_nodes='')
|
||||
|
||||
log_info('Models exported at %s' % (FLAGS.export_dir))
|
||||
except RuntimeError as e:
|
||||
|
|
Загрузка…
Ссылка в новой задаче