Reuben Morais 2019-08-28 17:53:24 +02:00
Parent 06dee673c7
Commit 670e06365e
6 changed files with 43 additions and 32 deletions

View file

@@ -44,7 +44,7 @@ def variable_on_cpu(name, shape, initializer):

 def create_overlapping_windows(batch_x):
-    batch_size = tf.shape(batch_x)[0]
+    batch_size = tf.shape(input=batch_x)[0]
     window_width = 2 * Config.n_context + 1
     num_channels = Config.n_input
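
Note: this hunk, and the tf.reduce_mean/tf.transpose/tf.nn.conv1d hunks below, are the same mechanical migration: positional arguments replaced with the canonical TF 2.x keyword names, which is what the tf_upgrade_v2 conversion script emits. A minimal sketch of the equivalence, assuming TensorFlow 1.14+ and an illustrative tensor:

import tensorflow as tf

x = tf.zeros([16, 26])  # illustrative [n_steps, n_input] tensor

# Each pair builds the same op; only the keyword form survives the 2.x cleanup.
s_old = tf.shape(x)                             # positional, 1.x style
s_new = tf.shape(input=x)                       # keyword, 2.x-compatible
m_old = tf.reduce_mean(x, 0)                    # positional axis
m_new = tf.reduce_mean(input_tensor=x, axis=0)  # keyword axis
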
@@ -55,7 +55,7 @@ def create_overlapping_windows(batch_x):
                              .reshape(window_width, num_channels, window_width * num_channels), tf.float32) # pylint: disable=bad-continuation

     # Create overlapping windows
-    batch_x = tf.nn.conv1d(batch_x, eye_filter, stride=1, padding='SAME')
+    batch_x = tf.nn.conv1d(input=batch_x, filters=eye_filter, stride=1, padding='SAME')

     # Remove dummy depth dimension and reshape into [batch_size, n_windows, window_width, n_input]
     batch_x = tf.reshape(batch_x, [batch_size, -1, window_width, num_channels])
@@ -65,8 +65,8 @@ def create_overlapping_windows(batch_x):

 def dense(name, x, units, dropout_rate=None, relu=True):
     with tfv1.variable_scope(name):
-        bias = variable_on_cpu('bias', [units], tf.zeros_initializer())
-        weights = variable_on_cpu('weights', [x.shape[-1], units], tf.contrib.layers.xavier_initializer())
+        bias = variable_on_cpu('bias', [units], tfv1.zeros_initializer())
+        weights = variable_on_cpu('weights', [x.shape[-1], units], tfv1.keras.initializers.VarianceScaling(scale=1.0, mode="fan_avg", distribution="uniform"))
         output = tf.nn.bias_add(tf.matmul(x, weights), bias)
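
Note: tf.contrib.layers.xavier_initializer() defaults to Glorot uniform, sampling from [-limit, limit] with limit = sqrt(6 / (fan_in + fan_out)). VarianceScaling with scale=1.0, mode="fan_avg", distribution="uniform" uses limit = sqrt(3 * scale / n) with n = (fan_in + fan_out) / 2, which is the same bound, so the swap is behavior-preserving. A quick check of that identity (layer sizes are illustrative):

import numpy as np

fan_in, fan_out = 494, 2048  # illustrative dense-layer dimensions

# Bound used by xavier_initializer() with its uniform default
xavier_limit = np.sqrt(6.0 / (fan_in + fan_out))

# Bound used by VarianceScaling(scale=1.0, mode='fan_avg', distribution='uniform')
vs_limit = np.sqrt(3.0 * 1.0 / ((fan_in + fan_out) / 2.0))

assert np.isclose(xavier_limit, vs_limit)  # identical distributions
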
@@ -147,7 +147,7 @@ def create_model(batch_x, seq_length, dropout, reuse=False, batch_size=None, pre
     # Input shape: [batch_size, n_steps, n_input + 2*n_input*n_context]
     if not batch_size:
-        batch_size = tf.shape(batch_x)[0]
+        batch_size = tf.shape(input=batch_x)[0]

     # Create overlapping feature windows if needed
     if overlap:
@@ -157,7 +157,7 @@ def create_model(batch_x, seq_length, dropout, reuse=False, batch_size=None, pre
     # This is done to prepare the batch for input into the first layer which expects a tensor of rank `2`.

     # Permute n_steps and batch_size
-    batch_x = tf.transpose(batch_x, [1, 0, 2, 3])
+    batch_x = tf.transpose(a=batch_x, perm=[1, 0, 2, 3])
     # Reshape to prepare input for first layer
     batch_x = tf.reshape(batch_x, [-1, Config.n_input + 2*Config.n_input*Config.n_context]) # (n_steps*batch_size, n_input + 2*n_input*n_context)
     layers['input_reshaped'] = batch_x
@@ -232,7 +232,7 @@ def calculate_mean_edit_distance_and_loss(iterator, dropout, reuse):
     non_finite_files = tf.gather(batch_filenames, tfv1.where(~tf.math.is_finite(total_loss)))

     # Calculate the average loss across the batch
-    avg_loss = tf.reduce_mean(total_loss)
+    avg_loss = tf.reduce_mean(input_tensor=total_loss)

     # Finally we return the average loss
     return avg_loss, non_finite_files
@@ -312,7 +312,7 @@ def get_tower_results(iterator, optimizer, dropout_rates):
                 tower_non_finite_files.append(non_finite_files)

-    avg_loss_across_towers = tf.reduce_mean(tower_avg_losses, 0)
+    avg_loss_across_towers = tf.reduce_mean(input_tensor=tower_avg_losses, axis=0)

    tfv1.summary.scalar(name='step_loss', tensor=avg_loss_across_towers, collections=['step_summaries'])

    all_non_finite_files = tf.concat(tower_non_finite_files, axis=0)
@@ -346,7 +346,7 @@ def average_gradients(tower_gradients):
         # Average over the 'tower' dimension
         grad = tf.concat(grads, 0)
-        grad = tf.reduce_mean(grad, 0)
+        grad = tf.reduce_mean(input_tensor=grad, axis=0)

         # Create a gradient/variable tuple for the current variable with its average gradient
         grad_and_var = (grad, grad_and_vars[0][1])
@@ -369,11 +369,11 @@ def log_variable(variable, gradient=None):
     Furthermore it logs a histogram of its state and (if given) of an optimization gradient.
     '''
     name = variable.name.replace(':', '_')
-    mean = tf.reduce_mean(variable)
+    mean = tf.reduce_mean(input_tensor=variable)
     tfv1.summary.scalar(name='%s/mean' % name, tensor=mean)
-    tfv1.summary.scalar(name='%s/sttdev' % name, tensor=tf.sqrt(tf.reduce_mean(tf.square(variable - mean))))
-    tfv1.summary.scalar(name='%s/max' % name, tensor=tf.reduce_max(variable))
-    tfv1.summary.scalar(name='%s/min' % name, tensor=tf.reduce_min(variable))
+    tfv1.summary.scalar(name='%s/sttdev' % name, tensor=tf.sqrt(tf.reduce_mean(input_tensor=tf.square(variable - mean))))
+    tfv1.summary.scalar(name='%s/max' % name, tensor=tf.reduce_max(input_tensor=variable))
+    tfv1.summary.scalar(name='%s/min' % name, tensor=tf.reduce_min(input_tensor=variable))
     tfv1.summary.histogram(name=name, values=variable)
     if gradient is not None:
         if isinstance(gradient, tf.IndexedSlices):
@@ -667,7 +667,7 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
     previous_state_c = tfv1.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_c')
     previous_state_h = tfv1.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_h')

-    previous_state = tf.contrib.rnn.LSTMStateTuple(previous_state_c, previous_state_h)
+    previous_state = tf.nn.rnn_cell.LSTMStateTuple(previous_state_c, previous_state_h)

     # One rate per layer
     no_dropout = [None] * 6
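
Note: tf.contrib is gone in TF 2.x, and tf.contrib.rnn.LSTMStateTuple was already an alias of tf.nn.rnn_cell.LSTMStateTuple, a plain (c, h) namedtuple, so this hunk is a pure rename. A minimal sketch, assuming TF 1.14 and an illustrative cell size:

import tensorflow as tf
import tensorflow.compat.v1 as tfv1

n_cell_dim = 2048  # illustrative
c = tfv1.placeholder(tf.float32, [1, n_cell_dim], name='previous_state_c')
h = tfv1.placeholder(tf.float32, [1, n_cell_dim], name='previous_state_h')

# A namedtuple with fields .c (cell state) and .h (hidden/output state)
state = tf.nn.rnn_cell.LSTMStateTuple(c, h)
assert state.c is c and state.h is h
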

View file

@@ -1,14 +1,20 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-

-import tensorflow as tf
+import tensorflow.compat.v1 as tfv1
 import sys

-# Load and export as string
-with tf.gfile.FastGFile(sys.argv[1], 'rb') as fin:
-    graph_def = tf.GraphDef()
-    graph_def.ParseFromString(fin.read())
-    with tf.gfile.FastGFile(sys.argv[1] + 'txt', 'w') as fout:
-        from google.protobuf import text_format
-        fout.write(text_format.MessageToString(graph_def))
+from google.protobuf import text_format
+
+def main():
+    # Load and export as string
+    with tfv1.gfile.FastGFile(sys.argv[1], 'rb') as fin:
+        graph_def = tfv1.GraphDef()
+        graph_def.ParseFromString(fin.read())
+
+        with tfv1.gfile.FastGFile(sys.argv[1] + 'txt', 'w') as fout:
+            fout.write(text_format.MessageToString(graph_def))
+
+if __name__ == '__main__':
+    main()
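
Note: the rewrite only hoists the protobuf import and wraps the existing logic in main(); behavior is unchanged. The script still writes the text rendering to sys.argv[1] + 'txt', so model.pb becomes model.pbtxt. For reference, a sketch of the reverse conversion via text_format.Parse, with illustrative file names:

import tensorflow.compat.v1 as tfv1
from google.protobuf import text_format

# Parse a text GraphDef back into binary form (paths are illustrative)
graph_def = tfv1.GraphDef()
with tfv1.gfile.FastGFile('model.pbtxt', 'r') as fin:
    text_format.Parse(fin.read(), graph_def)

with tfv1.gfile.FastGFile('model.pb', 'wb') as fout:
    fout.write(graph_def.SerializeToString())
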

View file

@@ -1,11 +1,15 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-

-import tensorflow as tf
+import tensorflow.compat.v1 as tfv1
 import sys

-with tf.gfile.FastGFile(sys.argv[1], 'rb') as fin:
-    graph_def = tf.GraphDef()
-    graph_def.ParseFromString(fin.read())
-
-print('\n'.join(sorted(set(n.op for n in graph_def.node))))
+def main():
+    with tfv1.gfile.FastGFile(sys.argv[1], 'rb') as fin:
+        graph_def = tfv1.GraphDef()
+        graph_def.ParseFromString(fin.read())

+        print('\n'.join(sorted(set(n.op for n in graph_def.node))))
+
+if __name__ == '__main__':
+    main()
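
Note: same main() wrapper as the previous file. A listing like this is handy for checking a frozen graph against the ops a trimmed-down inference runtime registers. A variant sketch that also tallies how many nodes use each op (collections.Counter is the only addition):

import collections
import sys

import tensorflow.compat.v1 as tfv1

def main():
    graph_def = tfv1.GraphDef()
    with tfv1.gfile.FastGFile(sys.argv[1], 'rb') as fin:
        graph_def.ParseFromString(fin.read())

    # Node count per op type, most frequent first
    for op, count in collections.Counter(n.op for n in graph_def.node).most_common():
        print('%6d %s' % (count, op))

if __name__ == '__main__':
    main()
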

View file

@@ -63,7 +63,7 @@ def evaluate(test_csvs, create_model, try_loading):
                                      dropout=no_dropout)

     # Transpose to batch major and apply softmax for decoder
-    transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))
+    transposed = tf.nn.softmax(tf.transpose(a=logits, perm=[1, 0, 2]))

     loss = tfv1.nn.ctc_loss(labels=batch_y,
                             inputs=logits,

View file

@@ -2,6 +2,7 @@ from __future__ import absolute_import, division, print_function
 import os

 import tensorflow as tf
+import tensorflow.compat.v1 as tfv1

 from attrdict import AttrDict
 from xdg import BaseDirectory as xdg
@@ -57,9 +58,9 @@ def initialize_globals():
         FLAGS.summary_dir = xdg.save_data_path(os.path.join('deepspeech', 'summaries'))

     # Standard session configuration that'll be used for all new sessions.
-    c.session_config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=FLAGS.log_placement,
-                                      inter_op_parallelism_threads=FLAGS.inter_op_parallelism_threads,
-                                      intra_op_parallelism_threads=FLAGS.intra_op_parallelism_threads)
+    c.session_config = tfv1.ConfigProto(allow_soft_placement=True, log_device_placement=FLAGS.log_placement,
+                                        inter_op_parallelism_threads=FLAGS.inter_op_parallelism_threads,
+                                        intra_op_parallelism_threads=FLAGS.intra_op_parallelism_threads)

     c.alphabet = Alphabet(os.path.abspath(FLAGS.alphabet_config_path))
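
Note: in TF 2.x ConfigProto is only reachable as tf.compat.v1.ConfigProto, hence the new tfv1 import alongside the plain tensorflow one. A minimal sketch of such a config in use, assuming TF 1.14 (flag values are illustrative; thread counts of 0 let TensorFlow choose):

import tensorflow.compat.v1 as tfv1

session_config = tfv1.ConfigProto(allow_soft_placement=True,
                                  log_device_placement=False,
                                  inter_op_parallelism_threads=0,
                                  intra_op_parallelism_threads=0)

# Every new session picks up the shared configuration
with tfv1.Session(config=session_config) as session:
    print(session.run(tfv1.constant('session configured')))
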

View file

@@ -39,7 +39,7 @@ def samples_to_mfccs(samples, sample_rate):
     mfccs = contrib_audio.mfcc(spectrogram, sample_rate, dct_coefficient_count=Config.n_input)
     mfccs = tf.reshape(mfccs, [-1, Config.n_input])

-    return mfccs, tf.shape(mfccs)[0]
+    return mfccs, tf.shape(input=mfccs)[0]


 def audiofile_to_features(wav_filename):