Merge pull request #1229 from mozilla/export-cleanup

Clean up export code and remove some TF 1.0 compat code (Fixes #1228)
This commit is contained in:
Reuben Morais 2018-02-14 14:06:16 -02:00 committed by GitHub
Parents 5d3abe8948 983e13f218
Commit 32c06acdf8
No key found matching this signature
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 21 additions and 44 deletions

@@ -11,6 +11,7 @@ os.environ['TF_CPP_MIN_LOG_LEVEL'] = sys.argv[log_level_index] if log_level_inde
import datetime
import pickle
import shutil
import six
import subprocess
import tensorflow as tf
import time
@@ -18,7 +19,6 @@ import traceback
import inspect
from six.moves import zip, range, filter, urllib, BaseHTTPServer
from tensorflow.contrib.session_bundle import exporter
from tensorflow.python.tools import freeze_graph
from threading import Thread, Lock
from util.audio import audiofile_to_input_vector
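The newly imported six module is used further down in export() to iterate the outputs dictionary in a Python 2/3 compatible way when assembling the comma-separated list of node names to freeze. A minimal, standalone sketch of that expression (the outputs dict and the namedtuples here are hypothetical stand-ins for the tensors returned by create_inference_graph(), not code from this commit):

import collections
import six

# Hypothetical stand-ins mimicking tf.Tensor objects and their .op.name attribute.
FakeOp = collections.namedtuple('FakeOp', ['name'])
FakeTensor = collections.namedtuple('FakeTensor', ['op'])

outputs = {'outputs': FakeTensor(FakeOp('output_node'))}

# Same expression the new export() code uses to name the graph nodes to keep:
output_node_names = ','.join(t.op.name for t in six.itervalues(outputs))
print(output_node_names)  # -> "output_node"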
@@ -432,18 +432,14 @@ def BiRNN(batch_x, seq_length, dropout):
# Now we create the forward and backward LSTM units.
# Both of which have inputs of length `n_cell_dim` and bias `1.0` for the forget gate of the LSTM.
# Forward direction cell: (if else required for TF 1.0 and 1.1 compat)
lstm_fw_cell = tf.contrib.rnn.BasicLSTMCell(n_cell_dim, forget_bias=1.0, state_is_tuple=True) \
if 'reuse' not in inspect.getargspec(tf.contrib.rnn.BasicLSTMCell.__init__).args else \
tf.contrib.rnn.BasicLSTMCell(n_cell_dim, forget_bias=1.0, state_is_tuple=True, reuse=tf.get_variable_scope().reuse)
# Forward direction cell:
lstm_fw_cell = tf.contrib.rnn.BasicLSTMCell(n_cell_dim, forget_bias=1.0, state_is_tuple=True, reuse=tf.get_variable_scope().reuse)
lstm_fw_cell = tf.contrib.rnn.DropoutWrapper(lstm_fw_cell,
input_keep_prob=1.0 - dropout[3],
output_keep_prob=1.0 - dropout[3],
seed=FLAGS.random_seed)
# Backward direction cell: (if else required for TF 1.0 and 1.1 compat)
lstm_bw_cell = tf.contrib.rnn.BasicLSTMCell(n_cell_dim, forget_bias=1.0, state_is_tuple=True) \
if 'reuse' not in inspect.getargspec(tf.contrib.rnn.BasicLSTMCell.__init__).args else \
tf.contrib.rnn.BasicLSTMCell(n_cell_dim, forget_bias=1.0, state_is_tuple=True, reuse=tf.get_variable_scope().reuse)
# Backward direction cell:
lstm_bw_cell = tf.contrib.rnn.BasicLSTMCell(n_cell_dim, forget_bias=1.0, state_is_tuple=True, reuse=tf.get_variable_scope().reuse)
lstm_bw_cell = tf.contrib.rnn.DropoutWrapper(lstm_bw_cell,
input_keep_prob=1.0 - dropout[4],
output_keep_prob=1.0 - dropout[4],
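For reference, the removed branches used inspect.getargspec to detect at run time whether the installed BasicLSTMCell constructor accepts a reuse keyword, which the older TF 1.0/1.1 releases did not (per the removed comment). A minimal sketch of that feature-detection pattern, shown only to document what the simplification drops:

import inspect
import tensorflow as tf

def make_lstm_cell(n_cell_dim, reuse=None):
    # Pass `reuse` only if the installed BasicLSTMCell accepts it; this is the
    # TF 1.0/1.1 compatibility pattern that this commit removes.
    ctor = tf.contrib.rnn.BasicLSTMCell
    if 'reuse' in inspect.getargspec(ctor.__init__).args:
        return ctor(n_cell_dim, forget_bias=1.0, state_is_tuple=True, reuse=reuse)
    return ctor(n_cell_dim, forget_bias=1.0, state_is_tuple=True)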
@@ -1685,7 +1681,7 @@ def train(server=None):
' or removing the contents of {0}.'.format(FLAGS.checkpoint_dir))
sys.exit(1)
def create_inference_graph(batch_size=None, output_is_logits=False, use_new_decoder=False):
def create_inference_graph(batch_size=None, use_new_decoder=False):
# Input tensor will be of shape [batch_size, n_steps, n_input + 2*n_input*n_context]
input_tensor = tf.placeholder(tf.float32, [batch_size, None, n_input + 2*n_input*n_context], name='input_node')
seq_length = tf.placeholder(tf.int32, [batch_size], name='input_lengths')
@@ -1693,9 +1689,6 @@ def create_inference_graph(batch_size=None, output_is_logits=False, use_new_deco
# Calculate the logits of the batch using BiRNN
logits = BiRNN(input_tensor, tf.to_int64(seq_length) if FLAGS.use_seq_length else None, no_dropout)
if output_is_logits:
return logits
# Beam search decode the batch
decoder = decode_with_lm if use_new_decoder else tf.nn.ctc_beam_search_decoder
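When the LM-augmented decoder is not requested, the batch falls back to TensorFlow's stock beam-search decoder. A small, self-contained example of tf.nn.ctc_beam_search_decoder on toy logits (shapes and class count chosen only for illustration, not taken from DeepSpeech.py):

import numpy as np
import tensorflow as tf

# Toy logits: 5 time steps, batch of 1, 4 classes (the last class is the CTC blank).
np.random.seed(0)
logits = tf.constant(np.random.randn(5, 1, 4).astype(np.float32))
seq_length = tf.constant([5], dtype=tf.int32)

decoded, log_probs = tf.nn.ctc_beam_search_decoder(logits, seq_length,
                                                   beam_width=4, top_paths=1)

with tf.Session() as session:
    sparse_result = session.run(decoded[0])
    print(sparse_result.values)  # label indices of the best path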
@@ -1730,48 +1723,32 @@ def export():
# Create a saver and exporter using variables from the above newly created graph
saver = tf.train.Saver(tf.global_variables())
model_exporter = exporter.Exporter(saver)
# Restore variables from training checkpoint
# TODO: This restores the most recent checkpoint, but if we use validation to counterract
# over-fitting, we may want to restore an earlier checkpoint.
checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
checkpoint_path = checkpoint.model_checkpoint_path
saver.restore(session, checkpoint_path)
log_info('Restored checkpoint at training epoch %d' % (int(checkpoint_path.split('-')[-1]) + 1))
# Initialise the model exporter and export the model
model_exporter.init(session.graph.as_graph_def(),
named_graph_signatures = {
'inputs': exporter.generic_signature(inputs),
'outputs': exporter.generic_signature(outputs)
})
if FLAGS.remove_export:
actual_export_dir = os.path.join(FLAGS.export_dir, '%08d' % FLAGS.export_version)
if os.path.isdir(actual_export_dir):
if os.path.isdir(FLAGS.export_dir):
log_info('Removing old export')
shutil.rmtree(actual_export_dir)
shutil.rmtree(FLAGS.export_dir)
try:
# Export serving model
model_exporter.export(FLAGS.export_dir, tf.constant(FLAGS.export_version), session)
output_graph_path = os.path.join(FLAGS.export_dir, 'output_graph.pb')
# Export graph
input_graph_name = 'input_graph.pb'
tf.train.write_graph(session.graph, FLAGS.export_dir, input_graph_name, as_text=False)
if not os.path.isdir(FLAGS.export_dir):
os.makedirs(FLAGS.export_dir)
# Freeze graph
input_graph_path = os.path.join(FLAGS.export_dir, input_graph_name)
input_saver_def_path = ''
input_binary = True
output_node_names = 'output_node'
restore_op_name = 'save/restore_all'
filename_tensor_name = 'save/Const:0'
output_graph_path = os.path.join(FLAGS.export_dir, 'output_graph.pb')
clear_devices = False
freeze_graph.freeze_graph(input_graph_path, input_saver_def_path,
input_binary, checkpoint_path, output_node_names,
restore_op_name, filename_tensor_name,
output_graph_path, clear_devices, '')
freeze_graph.freeze_graph_with_def_protos(
input_graph_def=session.graph_def,
input_saver_def=saver.as_saver_def(),
input_checkpoint=checkpoint_path,
output_node_names=','.join(node.op.name for node in six.itervalues(outputs)),
restore_op_name=None,
filename_tensor_name=None,
output_graph=output_graph_path,
clear_devices=False,
initializer_nodes='')
log_info('Models exported at %s' % (FLAGS.export_dir))
except RuntimeError as e:
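Once export() finishes, the frozen inference graph is written to output_graph.pb inside FLAGS.export_dir. A hedged sketch of loading that artifact back for inference, using standard TF 1.x graph-loading calls rather than code from this repository:

import tensorflow as tf

def load_frozen_graph(pb_path):
    # Read the serialized GraphDef produced by freeze_graph and import it into
    # a fresh Graph, keeping the original node names ('input_node', 'output_node', ...).
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(pb_path, 'rb') as f:
        graph_def.ParseFromString(f.read())
    graph = tf.Graph()
    with graph.as_default():
        tf.import_graph_def(graph_def, name='')
    return graph

# Path assumes the file exported above has been copied to the working directory.
graph = load_frozen_graph('output_graph.pb')
with tf.Session(graph=graph) as session:
    print([op.name for op in graph.get_operations()][:10])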