Package training code to avoid sys.path hacks

Reuben Morais 2020-03-25 17:07:29 +01:00
Parent 58bc2f2bb1
Commit a05baa35c9
74 changed files with 1677 additions and 1731 deletions
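This commit replaces per-script sys.path manipulation with a proper Python package (`deepspeech_training`) that is installed via `pip install .` (see the CI and Docker changes below). The setup.py itself is not part of this excerpt; a minimal hypothetical sketch of the idea, with illustrative values only:

# setup.py -- hypothetical sketch, not the file from this commit
from setuptools import setup, find_packages

setup(
    name='deepspeech_training',   # the package name used by the new imports below
    version='0.0.1',              # illustrative
    packages=find_packages(),     # picks up deepspeech_training/ and its subpackages
)

Once installed, an import such as `from deepspeech_training.util.flags import create_flags` works from any working directory, which is what lets the scripts below drop their sys.path hacks.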

View file

@@ -9,7 +9,7 @@ python:
jobs:
  include:
-    - stage: cardboard linter
+    - name: cardboard linter
      install:
        - pip install --upgrade cardboardlint pylint
      script:
@@ -17,9 +17,10 @@ jobs:
        - if [ "$TRAVIS_PULL_REQUEST" != "false" ]; then
            cardboardlinter --refspec $TRAVIS_BRANCH -n auto;
          fi
-    - stage: python unit tests
+    - name: python unit tests
      install:
-        - pip install --upgrade -r requirements_tests.txt
+        - pip install --upgrade -r requirements_tests.txt;
+          pip install --upgrade .
      script:
        - if [ "$TRAVIS_PULL_REQUEST" != "false" ]; then
            python -m unittest;

View file

@@ -2,934 +2,11 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function
import os
import sys
LOG_LEVEL_INDEX = sys.argv.index('--log_level') + 1 if '--log_level' in sys.argv else 0
DESIRED_LOG_LEVEL = sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.argv) else '3'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = DESIRED_LOG_LEVEL
import absl.app
import json
import numpy as np
import progressbar
import shutil
import tensorflow as tf
import tensorflow.compat.v1 as tfv1
import time
tfv1.logging.set_verbosity({
'0': tfv1.logging.DEBUG,
'1': tfv1.logging.INFO,
'2': tfv1.logging.WARN,
'3': tfv1.logging.ERROR
}.get(DESIRED_LOG_LEVEL))
from datetime import datetime
from ds_ctcdecoder import ctc_beam_search_decoder, Scorer
from evaluate import evaluate
from six.moves import zip, range
from util.config import Config, initialize_globals
from util.checkpoints import load_or_init_graph
from util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features
from util.flags import create_flags, FLAGS
from util.helpers import check_ctcdecoder_version, ExceptionBox
from util.logging import log_info, log_error, log_debug, log_progress, create_progressbar
check_ctcdecoder_version()
# Graph Creation
# ==============
def variable_on_cpu(name, shape, initializer):
r"""
Next we concern ourselves with graph creation.
However, before we do so we must introduce a utility function ``variable_on_cpu()``
used to create a variable in CPU memory.
"""
# Use the /cpu:0 device for scoped operations
with tf.device(Config.cpu_device):
# Create or get the appropriate variable
var = tfv1.get_variable(name=name, shape=shape, initializer=initializer)
return var
def create_overlapping_windows(batch_x):
batch_size = tf.shape(input=batch_x)[0]
window_width = 2 * Config.n_context + 1
num_channels = Config.n_input
# Create a constant convolution filter using an identity matrix, so that the
# convolution returns patches of the input tensor as is, and we can create
# overlapping windows over the MFCCs.
eye_filter = tf.constant(np.eye(window_width * num_channels)
.reshape(window_width, num_channels, window_width * num_channels), tf.float32) # pylint: disable=bad-continuation
# Create overlapping windows
batch_x = tf.nn.conv1d(input=batch_x, filters=eye_filter, stride=1, padding='SAME')
# Remove dummy depth dimension and reshape into [batch_size, n_windows, window_width, n_input]
batch_x = tf.reshape(batch_x, [batch_size, -1, window_width, num_channels])
return batch_x
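# Editor's note -- illustrative sketch, not part of the original file. The
# identity-filter trick above is compact but non-obvious; this NumPy check
# shows what the reshaped identity matrix computes at one output position,
# assuming window_width=3 and num_channels=2:
def _overlapping_windows_sanity_check():
    eye = np.eye(3 * 2).reshape(3, 2, 6)         # same reshape as eye_filter above
    x = np.arange(6, dtype=float).reshape(3, 2)  # 3 time steps, 2 channels
    # conv1d computes sum_k x[t+k] @ eye[k] at each position t; with the
    # identity filter this simply concatenates the window into one flat vector.
    out = sum(x[k] @ eye[k] for k in range(3))
    assert np.allclose(out, x.reshape(-1))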
def dense(name, x, units, dropout_rate=None, relu=True):
with tfv1.variable_scope(name):
bias = variable_on_cpu('bias', [units], tfv1.zeros_initializer())
weights = variable_on_cpu('weights', [x.shape[-1], units], tfv1.keras.initializers.VarianceScaling(scale=1.0, mode="fan_avg", distribution="uniform"))
output = tf.nn.bias_add(tf.matmul(x, weights), bias)
if relu:
output = tf.minimum(tf.nn.relu(output), FLAGS.relu_clip)
if dropout_rate is not None:
output = tf.nn.dropout(output, rate=dropout_rate)
return output
def rnn_impl_lstmblockfusedcell(x, seq_length, previous_state, reuse):
with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell/cell_0'):
fw_cell = tf.contrib.rnn.LSTMBlockFusedCell(Config.n_cell_dim,
forget_bias=0,
reuse=reuse,
name='cudnn_compatible_lstm_cell')
output, output_state = fw_cell(inputs=x,
dtype=tf.float32,
sequence_length=seq_length,
initial_state=previous_state)
return output, output_state
def rnn_impl_cudnn_rnn(x, seq_length, previous_state, _):
assert previous_state is None # 'Passing previous state not supported with CuDNN backend'
# Hack: CudnnLSTM works similarly to Keras layers in that when you instantiate
# the object it creates the variables, and then you just call it several times
# to enable variable re-use. Because all of our code is structured in an old-
# school TensorFlow style where you can just call tf.get_variable again with
# reuse=True to reuse variables, we can't easily make use of the object oriented
# way CudnnLSTM is implemented, so we save a singleton instance in the function,
# emulating a static function variable.
if not rnn_impl_cudnn_rnn.cell:
# Forward direction cell:
fw_cell = tf.contrib.cudnn_rnn.CudnnLSTM(num_layers=1,
num_units=Config.n_cell_dim,
input_mode='linear_input',
direction='unidirectional',
dtype=tf.float32)
rnn_impl_cudnn_rnn.cell = fw_cell
output, output_state = rnn_impl_cudnn_rnn.cell(inputs=x,
sequence_lengths=seq_length)
return output, output_state
rnn_impl_cudnn_rnn.cell = None
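# Editor's note -- illustrative sketch, not part of the original file. The
# function-attribute pattern used above to emulate a static local works for
# any expensive object that should be created once and reused on later calls:
def _get_shared_resource():
    if _get_shared_resource.instance is None:
        _get_shared_resource.instance = object()  # stand-in for e.g. a CudnnLSTM
    return _get_shared_resource.instance
_get_shared_resource.instance = None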
def rnn_impl_static_rnn(x, seq_length, previous_state, reuse):
with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell'):
# Forward direction cell:
fw_cell = tfv1.nn.rnn_cell.LSTMCell(Config.n_cell_dim,
forget_bias=0,
reuse=reuse,
name='cudnn_compatible_lstm_cell')
# Split rank N tensor into list of rank N-1 tensors
x = [x[l] for l in range(x.shape[0])]
output, output_state = tfv1.nn.static_rnn(cell=fw_cell,
inputs=x,
sequence_length=seq_length,
initial_state=previous_state,
dtype=tf.float32,
scope='cell_0')
output = tf.concat(output, 0)
return output, output_state
def create_model(batch_x, seq_length, dropout, reuse=False, batch_size=None, previous_state=None, overlap=True, rnn_impl=rnn_impl_lstmblockfusedcell):
layers = {}
# Input shape: [batch_size, n_steps, n_input + 2*n_input*n_context]
if not batch_size:
batch_size = tf.shape(input=batch_x)[0]
# Create overlapping feature windows if needed
if overlap:
batch_x = create_overlapping_windows(batch_x)
# Reshaping `batch_x` to a tensor with shape `[n_steps*batch_size, n_input + 2*n_input*n_context]`.
# This is done to prepare the batch for input into the first layer which expects a tensor of rank `2`.
# Permute n_steps and batch_size
batch_x = tf.transpose(a=batch_x, perm=[1, 0, 2, 3])
# Reshape to prepare input for first layer
batch_x = tf.reshape(batch_x, [-1, Config.n_input + 2*Config.n_input*Config.n_context]) # (n_steps*batch_size, n_input + 2*n_input*n_context)
layers['input_reshaped'] = batch_x
# The next three blocks will pass `batch_x` through three hidden layers with
# clipped RELU activation and dropout.
layers['layer_1'] = layer_1 = dense('layer_1', batch_x, Config.n_hidden_1, dropout_rate=dropout[0])
layers['layer_2'] = layer_2 = dense('layer_2', layer_1, Config.n_hidden_2, dropout_rate=dropout[1])
layers['layer_3'] = layer_3 = dense('layer_3', layer_2, Config.n_hidden_3, dropout_rate=dropout[2])
# `layer_3` is now reshaped into `[n_steps, batch_size, 2*n_cell_dim]`,
# as the LSTM RNN expects its input to be of shape `[max_time, batch_size, input_size]`.
layer_3 = tf.reshape(layer_3, [-1, batch_size, Config.n_hidden_3])
# Run through parametrized RNN implementation, as we use different RNNs
# for training and inference
output, output_state = rnn_impl(layer_3, seq_length, previous_state, reuse)
# Reshape output from a tensor of shape [n_steps, batch_size, n_cell_dim]
# to a tensor of shape [n_steps*batch_size, n_cell_dim]
output = tf.reshape(output, [-1, Config.n_cell_dim])
layers['rnn_output'] = output
layers['rnn_output_state'] = output_state
# Now we feed `output` to the fifth hidden layer with clipped RELU activation
layers['layer_5'] = layer_5 = dense('layer_5', output, Config.n_hidden_5, dropout_rate=dropout[5])
# Now we apply a final linear layer creating `n_classes` dimensional vectors, the logits.
layers['layer_6'] = layer_6 = dense('layer_6', layer_5, Config.n_hidden_6, relu=False)
# Finally we reshape layer_6 from a tensor of shape [n_steps*batch_size, n_hidden_6]
# to the slightly more useful shape [n_steps, batch_size, n_hidden_6].
# Note that this differs from the input in that it is time-major.
layer_6 = tf.reshape(layer_6, [-1, batch_size, Config.n_hidden_6], name='raw_logits')
layers['raw_logits'] = layer_6
# Output shape: [n_steps, batch_size, n_hidden_6]
return layer_6, layers
# Accuracy and Loss
# =================
# In accord with 'Deep Speech: Scaling up end-to-end speech recognition'
# (http://arxiv.org/abs/1412.5567),
# the loss function used by our network should be the CTC loss function
# (http://www.cs.toronto.edu/~graves/preprint.pdf).
# Conveniently, this loss function is implemented in TensorFlow.
# Thus, we can simply make use of this implementation to define our loss.
def calculate_mean_edit_distance_and_loss(iterator, dropout, reuse):
r'''
This routine runs a mini-batch through the model and computes the CTC loss.
Despite the legacy name, it no longer decodes or computes an edit distance:
it returns the average loss across the batch, along with the names of any
files in the batch that produced a non-finite loss.
'''
# Obtain the next batch of data
batch_filenames, (batch_x, batch_seq_len), batch_y = iterator.get_next()
if FLAGS.train_cudnn:
rnn_impl = rnn_impl_cudnn_rnn
else:
rnn_impl = rnn_impl_lstmblockfusedcell
# Calculate the logits of the batch
logits, _ = create_model(batch_x, batch_seq_len, dropout, reuse=reuse, rnn_impl=rnn_impl)
# Compute the CTC loss using TensorFlow's `ctc_loss`
total_loss = tfv1.nn.ctc_loss(labels=batch_y, inputs=logits, sequence_length=batch_seq_len)
# Check if any files lead to non finite loss
non_finite_files = tf.gather(batch_filenames, tfv1.where(~tf.math.is_finite(total_loss)))
# Calculate the average loss across the batch
avg_loss = tf.reduce_mean(input_tensor=total_loss)
# Finally we return the average loss
return avg_loss, non_finite_files
# Adam Optimization
# =================
# In contrast to 'Deep Speech: Scaling up end-to-end speech recognition'
# (http://arxiv.org/abs/1412.5567),
# in which 'Nesterov's Accelerated Gradient Descent'
# (www.cs.toronto.edu/~fritz/absps/momentum.pdf) was used,
# we will use the Adam method for optimization (http://arxiv.org/abs/1412.6980),
# because, generally, it requires less fine-tuning.
def create_optimizer(learning_rate_var):
optimizer = tfv1.train.AdamOptimizer(learning_rate=learning_rate_var,
beta1=FLAGS.beta1,
beta2=FLAGS.beta2,
epsilon=FLAGS.epsilon)
return optimizer
# Towers
# ======
# In order to properly make use of multiple GPUs, one must introduce new abstractions,
# not present when using a single GPU, that facilitate the multi-GPU use case.
# In particular, one must introduce a means to isolate the inference and gradient
# calculations on the various GPUs.
# The abstraction we introduce for this purpose is called a 'tower'.
# A tower is specified by two properties:
# * **Scope** - A scope, as provided by `tf.name_scope()`,
# is a means to isolate the operations within a tower.
# For example, all operations within 'tower 0' could have their name prefixed with `tower_0/`.
# * **Device** - A hardware device, as provided by `tf.device()`,
# on which all operations within the tower execute.
# For example, all operations of 'tower 0' could execute on the first GPU `tf.device('/gpu:0')`.
def get_tower_results(iterator, optimizer, dropout_rates):
r'''
With this preliminary step out of the way, we can, for each GPU, introduce a
tower whose batch we use to calculate and return the optimization gradients
and the average loss across towers.
'''
# To calculate the mean of the losses
tower_avg_losses = []
# Tower gradients to return
tower_gradients = []
# Aggregate any non finite files in the batches
tower_non_finite_files = []
with tfv1.variable_scope(tfv1.get_variable_scope()):
# Loop over available_devices
for i in range(len(Config.available_devices)):
# Execute operations of tower i on device i
device = Config.available_devices[i]
with tf.device(device):
# Create a scope for all operations of tower i
with tf.name_scope('tower_%d' % i):
# Calculate the avg_loss and mean_edit_distance and retrieve the decoded
# batch along with the original batch's labels (Y) of this tower
avg_loss, non_finite_files = calculate_mean_edit_distance_and_loss(iterator, dropout_rates, reuse=i > 0)
# Allow for variables to be re-used by the next tower
tfv1.get_variable_scope().reuse_variables()
# Retain tower's avg losses
tower_avg_losses.append(avg_loss)
# Compute gradients for model parameters using tower's mini-batch
gradients = optimizer.compute_gradients(avg_loss)
# Retain tower's gradients
tower_gradients.append(gradients)
tower_non_finite_files.append(non_finite_files)
avg_loss_across_towers = tf.reduce_mean(input_tensor=tower_avg_losses, axis=0)
tfv1.summary.scalar(name='step_loss', tensor=avg_loss_across_towers, collections=['step_summaries'])
all_non_finite_files = tf.concat(tower_non_finite_files, axis=0)
# Return gradients and the average loss
return tower_gradients, avg_loss_across_towers, all_non_finite_files
def average_gradients(tower_gradients):
r'''
A routine for computing, for each variable, the average of the gradients obtained from the GPUs.
Note also that this code acts as a synchronization point as it requires all
GPUs to be finished with their mini-batch before it can run to completion.
'''
# List of average gradients to return to the caller
average_grads = []
# Run this on cpu_device to conserve GPU memory
with tf.device(Config.cpu_device):
# Loop over gradient/variable pairs from all towers
for grad_and_vars in zip(*tower_gradients):
# Introduce grads to store the gradients for the current variable
grads = []
# Loop over the gradients for the current variable
for g, _ in grad_and_vars:
# Add 0 dimension to the gradients to represent the tower.
expanded_g = tf.expand_dims(g, 0)
# Append on a 'tower' dimension which we will average over below.
grads.append(expanded_g)
# Average over the 'tower' dimension
grad = tf.concat(grads, 0)
grad = tf.reduce_mean(input_tensor=grad, axis=0)
# Create a gradient/variable tuple for the current variable with its average gradient
grad_and_var = (grad, grad_and_vars[0][1])
# Add the current tuple to average_grads
average_grads.append(grad_and_var)
# Return result to caller
return average_grads
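# Editor's note -- illustrative sketch, not part of the original file. The
# averaging above is plain elementwise math; a NumPy toy check of the same
# logic for a single variable shared by two towers:
def _average_gradients_sanity_check():
    tower_grads = [np.array([1.0, 2.0]), np.array([3.0, 4.0])]  # one per tower
    # Stacking along a new 'tower' axis and reducing over it mirrors the
    # expand_dims/concat/reduce_mean sequence in average_gradients().
    assert np.allclose(np.mean(tower_grads, axis=0), [2.0, 3.0])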
# Logging
# =======
def log_variable(variable, gradient=None):
r'''
We introduce a function for logging a tensor variable's current state.
It logs scalar values for the mean, standard deviation, minimum and maximum.
Furthermore it logs a histogram of its state and (if given) of an optimization gradient.
'''
name = variable.name.replace(':', '_')
mean = tf.reduce_mean(input_tensor=variable)
tfv1.summary.scalar(name='%s/mean' % name, tensor=mean)
tfv1.summary.scalar(name='%s/stddev' % name, tensor=tf.sqrt(tf.reduce_mean(input_tensor=tf.square(variable - mean))))
tfv1.summary.scalar(name='%s/max' % name, tensor=tf.reduce_max(input_tensor=variable))
tfv1.summary.scalar(name='%s/min' % name, tensor=tf.reduce_min(input_tensor=variable))
tfv1.summary.histogram(name=name, values=variable)
if gradient is not None:
if isinstance(gradient, tf.IndexedSlices):
grad_values = gradient.values
else:
grad_values = gradient
if grad_values is not None:
tfv1.summary.histogram(name='%s/gradients' % name, values=grad_values)
def log_grads_and_vars(grads_and_vars):
r'''
Let's also introduce a helper function for logging collections of gradient/variable tuples.
'''
for gradient, variable in grads_and_vars:
log_variable(variable, gradient=gradient)
def train():
do_cache_dataset = True
# pylint: disable=too-many-boolean-expressions
if (FLAGS.data_aug_features_multiplicative > 0 or
FLAGS.data_aug_features_additive > 0 or
FLAGS.augmentation_spec_dropout_keeprate < 1 or
FLAGS.augmentation_freq_and_time_masking or
FLAGS.augmentation_pitch_and_tempo_scaling or
FLAGS.augmentation_speed_up_std > 0 or
FLAGS.augmentation_sparse_warp):
do_cache_dataset = False
exception_box = ExceptionBox()
# Create training and validation datasets
train_set = create_dataset(FLAGS.train_files.split(','),
batch_size=FLAGS.train_batch_size,
enable_cache=FLAGS.feature_cache and do_cache_dataset,
cache_path=FLAGS.feature_cache,
train_phase=True,
exception_box=exception_box,
process_ahead=len(Config.available_devices) * FLAGS.train_batch_size * 2,
buffering=FLAGS.read_buffer)
iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
tfv1.data.get_output_shapes(train_set),
output_classes=tfv1.data.get_output_classes(train_set))
# Make initialization ops for switching between the two sets
train_init_op = iterator.make_initializer(train_set)
if FLAGS.dev_files:
dev_sources = FLAGS.dev_files.split(',')
dev_sets = [create_dataset([source],
batch_size=FLAGS.dev_batch_size,
train_phase=False,
exception_box=exception_box,
process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2,
buffering=FLAGS.read_buffer) for source in dev_sources]
dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]
# Dropout
dropout_rates = [tfv1.placeholder(tf.float32, name='dropout_{}'.format(i)) for i in range(6)]
dropout_feed_dict = {
dropout_rates[0]: FLAGS.dropout_rate,
dropout_rates[1]: FLAGS.dropout_rate2,
dropout_rates[2]: FLAGS.dropout_rate3,
dropout_rates[3]: FLAGS.dropout_rate4,
dropout_rates[4]: FLAGS.dropout_rate5,
dropout_rates[5]: FLAGS.dropout_rate6,
}
no_dropout_feed_dict = {
rate: 0. for rate in dropout_rates
}
# Building the graph
learning_rate_var = tfv1.get_variable('learning_rate', initializer=FLAGS.learning_rate, trainable=False)
reduce_learning_rate_op = learning_rate_var.assign(tf.multiply(learning_rate_var, FLAGS.plateau_reduction))
optimizer = create_optimizer(learning_rate_var)
# Enable mixed precision training
if FLAGS.automatic_mixed_precision:
log_info('Enabling automatic mixed precision training.')
optimizer = tfv1.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)
gradients, loss, non_finite_files = get_tower_results(iterator, optimizer, dropout_rates)
# Average tower gradients across GPUs
avg_tower_gradients = average_gradients(gradients)
log_grads_and_vars(avg_tower_gradients)
# global_step is automagically incremented by the optimizer
global_step = tfv1.train.get_or_create_global_step()
apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients, global_step=global_step)
# Summaries
step_summaries_op = tfv1.summary.merge_all('step_summaries')
step_summary_writers = {
'train': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'train'), max_queue=120),
'dev': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'dev'), max_queue=120)
}
# Checkpointing
checkpoint_saver = tfv1.train.Saver(max_to_keep=FLAGS.max_to_keep)
checkpoint_path = os.path.join(FLAGS.save_checkpoint_dir, 'train')
best_dev_saver = tfv1.train.Saver(max_to_keep=1)
best_dev_path = os.path.join(FLAGS.save_checkpoint_dir, 'best_dev')
# Save flags next to checkpoints
os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True)
flags_file = os.path.join(FLAGS.save_checkpoint_dir, 'flags.txt')
with open(flags_file, 'w') as fout:
fout.write(FLAGS.flags_into_string())
with tfv1.Session(config=Config.session_config) as session:
log_debug('Session opened.')
# Prevent further graph changes
tfv1.get_default_graph().finalize()
# Load checkpoint or initialize variables
if FLAGS.load == 'auto':
method_order = ['best', 'last', 'init']
else:
method_order = [FLAGS.load]
load_or_init_graph(session, method_order)
def run_set(set_name, epoch, init_op, dataset=None):
is_train = set_name == 'train'
train_op = apply_gradient_op if is_train else []
feed_dict = dropout_feed_dict if is_train else no_dropout_feed_dict
total_loss = 0.0
step_count = 0
step_summary_writer = step_summary_writers.get(set_name)
checkpoint_time = time.time()
# Setup progress bar
class LossWidget(progressbar.widgets.FormatLabel):
def __init__(self):
progressbar.widgets.FormatLabel.__init__(self, format='Loss: %(mean_loss)f')
def __call__(self, progress, data, **kwargs):
data['mean_loss'] = total_loss / step_count if step_count else 0.0
return progressbar.widgets.FormatLabel.__call__(self, progress, data, **kwargs)
prefix = 'Epoch {} | {:>10}'.format(epoch, 'Training' if is_train else 'Validation')
widgets = [' | ', progressbar.widgets.Timer(),
' | Steps: ', progressbar.widgets.Counter(),
' | ', LossWidget()]
suffix = ' | Dataset: {}'.format(dataset) if dataset else None
pbar = create_progressbar(prefix=prefix, widgets=widgets, suffix=suffix).start()
# Initialize iterator to the appropriate dataset
session.run(init_op)
# Batch loop
while True:
try:
_, current_step, batch_loss, problem_files, step_summary = \
session.run([train_op, global_step, loss, non_finite_files, step_summaries_op],
feed_dict=feed_dict)
exception_box.raise_if_set()
except tf.errors.InvalidArgumentError as err:
if FLAGS.augmentation_sparse_warp:
log_info("Ignoring sparse warp error: {}".format(err))
continue
else:
raise
except tf.errors.OutOfRangeError:
exception_box.raise_if_set()
break
if problem_files.size > 0:
problem_files = [f.decode('utf8') for f in problem_files[..., 0]]
log_error('The following files caused an infinite (or NaN) '
'loss: {}'.format(','.join(problem_files)))
total_loss += batch_loss
step_count += 1
pbar.update(step_count)
step_summary_writer.add_summary(step_summary, current_step)
if is_train and FLAGS.checkpoint_secs > 0 and time.time() - checkpoint_time > FLAGS.checkpoint_secs:
checkpoint_saver.save(session, checkpoint_path, global_step=current_step)
checkpoint_time = time.time()
pbar.finish()
mean_loss = total_loss / step_count if step_count > 0 else 0.0
return mean_loss, step_count
log_info('STARTING Optimization')
train_start_time = datetime.utcnow()
best_dev_loss = float('inf')
dev_losses = []
epochs_without_improvement = 0
try:
for epoch in range(FLAGS.epochs):
# Training
log_progress('Training epoch %d...' % epoch)
train_loss, _ = run_set('train', epoch, train_init_op)
log_progress('Finished training epoch %d - loss: %f' % (epoch, train_loss))
checkpoint_saver.save(session, checkpoint_path, global_step=global_step)
if FLAGS.dev_files:
# Validation
dev_loss = 0.0
total_steps = 0
for source, init_op in zip(dev_sources, dev_init_ops):
log_progress('Validating epoch %d on %s...' % (epoch, source))
set_loss, steps = run_set('dev', epoch, init_op, dataset=source)
dev_loss += set_loss * steps
total_steps += steps
log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, source, set_loss))
dev_loss = dev_loss / total_steps
dev_losses.append(dev_loss)
# Count epochs without an improvement for early stopping and reduction of learning rate on a plateau
# the improvement has to be greater than FLAGS.es_min_delta
if dev_loss > best_dev_loss - FLAGS.es_min_delta:
epochs_without_improvement += 1
else:
epochs_without_improvement = 0
# Save new best model
if dev_loss < best_dev_loss:
best_dev_loss = dev_loss
save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename='best_dev_checkpoint')
log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path))
# Early stopping
if FLAGS.early_stop and epochs_without_improvement == FLAGS.es_epochs:
log_info('Early stop triggered as the loss did not improve the last {} epochs'.format(
epochs_without_improvement))
break
# Reduce learning rate on plateau
if (FLAGS.reduce_lr_on_plateau and
epochs_without_improvement % FLAGS.plateau_epochs == 0 and epochs_without_improvement > 0):
# If the learning rate was reduced and there is still no improvement
# wait FLAGS.plateau_epochs before the learning rate is reduced again
session.run(reduce_learning_rate_op)
current_learning_rate = learning_rate_var.eval()
log_info('Encountered a plateau, reducing learning rate to {}'.format(
current_learning_rate))
except KeyboardInterrupt:
pass
log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time))
log_debug('Session closed.')
def test():
samples = evaluate(FLAGS.test_files.split(','), create_model)
if FLAGS.test_output_file:
# Save decoded tuples as JSON, converting NumPy floats to Python floats
json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float)
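# Editor's note -- illustrative, not part of the original file. The
# default=float argument above is what lets json serialize NumPy scalars:
# json calls default on any object it cannot encode, and float() turns a
# np.float32 into a plain Python float.
def _json_default_float_example():
    return json.dumps({'loss': np.float32(0.5)}, default=float)  # '{"loss": 0.5}'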
def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
batch_size = batch_size if batch_size > 0 else None
# Create feature computation graph
input_samples = tfv1.placeholder(tf.float32, [Config.audio_window_samples], 'input_samples')
samples = tf.expand_dims(input_samples, -1)
mfccs, _ = samples_to_mfccs(samples, FLAGS.audio_sample_rate)
mfccs = tf.identity(mfccs, name='mfccs')
# Input tensor will be of shape [batch_size, n_steps, 2*n_context+1, n_input]
# This shape is read by the native_client in DS_CreateModel to know the
# value of n_steps, n_context and n_input. Make sure you update the code
# there if this shape is changed.
input_tensor = tfv1.placeholder(tf.float32, [batch_size, n_steps if n_steps > 0 else None, 2 * Config.n_context + 1, Config.n_input], name='input_node')
seq_length = tfv1.placeholder(tf.int32, [batch_size], name='input_lengths')
if batch_size <= 0:
# no state management since n_step is expected to be dynamic too (see below)
previous_state = None
else:
previous_state_c = tfv1.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_c')
previous_state_h = tfv1.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_h')
previous_state = tf.nn.rnn_cell.LSTMStateTuple(previous_state_c, previous_state_h)
# One rate per layer
no_dropout = [None] * 6
if tflite:
rnn_impl = rnn_impl_static_rnn
else:
rnn_impl = rnn_impl_lstmblockfusedcell
logits, layers = create_model(batch_x=input_tensor,
batch_size=batch_size,
seq_length=seq_length if not FLAGS.export_tflite else None,
dropout=no_dropout,
previous_state=previous_state,
overlap=False,
rnn_impl=rnn_impl)
# TF Lite runtime will check that input dimensions are 1, 2 or 4
# by default we get 3, the middle one being batch_size which is forced to
# one on inference graph, so remove that dimension
if tflite:
logits = tf.squeeze(logits, [1])
# Apply softmax for CTC decoder
logits = tf.nn.softmax(logits, name='logits')
if batch_size <= 0:
if tflite:
raise NotImplementedError('dynamic batch_size does not support tflite nor streaming')
if n_steps > 0:
raise NotImplementedError('dynamic batch_size expects n_steps to be dynamic too')
return (
{
'input': input_tensor,
'input_lengths': seq_length,
},
{
'outputs': logits,
},
layers
)
new_state_c, new_state_h = layers['rnn_output_state']
new_state_c = tf.identity(new_state_c, name='new_state_c')
new_state_h = tf.identity(new_state_h, name='new_state_h')
inputs = {
'input': input_tensor,
'previous_state_c': previous_state_c,
'previous_state_h': previous_state_h,
'input_samples': input_samples,
}
if not FLAGS.export_tflite:
inputs['input_lengths'] = seq_length
outputs = {
'outputs': logits,
'new_state_c': new_state_c,
'new_state_h': new_state_h,
'mfccs': mfccs,
}
return inputs, outputs, layers
def file_relative_read(fname):
return open(os.path.join(os.path.dirname(__file__), fname)).read()
def export():
r'''
Restores the trained variables into a simpler graph that will be exported for serving.
'''
log_info('Exporting the model...')
from tensorflow.python.framework.ops import Tensor, Operation
inputs, outputs, _ = create_inference_graph(batch_size=FLAGS.export_batch_size, n_steps=FLAGS.n_steps, tflite=FLAGS.export_tflite)
graph_version = int(file_relative_read('GRAPH_VERSION').strip())
assert graph_version > 0
outputs['metadata_version'] = tf.constant([graph_version], name='metadata_version')
outputs['metadata_sample_rate'] = tf.constant([FLAGS.audio_sample_rate], name='metadata_sample_rate')
outputs['metadata_feature_win_len'] = tf.constant([FLAGS.feature_win_len], name='metadata_feature_win_len')
outputs['metadata_feature_win_step'] = tf.constant([FLAGS.feature_win_step], name='metadata_feature_win_step')
outputs['metadata_beam_width'] = tf.constant([FLAGS.export_beam_width], name='metadata_beam_width')
outputs['metadata_alphabet'] = tf.constant([Config.alphabet.serialize()], name='metadata_alphabet')
if FLAGS.export_language:
outputs['metadata_language'] = tf.constant([FLAGS.export_language.encode('utf-8')], name='metadata_language')
# Prevent further graph changes
tfv1.get_default_graph().finalize()
output_names_tensors = [tensor.op.name for tensor in outputs.values() if isinstance(tensor, Tensor)]
output_names_ops = [op.name for op in outputs.values() if isinstance(op, Operation)]
output_names = output_names_tensors + output_names_ops
with tf.Session() as session:
# Restore variables from checkpoint
if FLAGS.load == 'auto':
method_order = ['best', 'last']
else:
method_order = [FLAGS.load]
load_or_init_graph(session, method_order)
output_filename = FLAGS.export_file_name + '.pb'
if FLAGS.remove_export:
if os.path.isdir(FLAGS.export_dir):
log_info('Removing old export')
shutil.rmtree(FLAGS.export_dir)
output_graph_path = os.path.join(FLAGS.export_dir, output_filename)
if not os.path.isdir(FLAGS.export_dir):
os.makedirs(FLAGS.export_dir)
frozen_graph = tfv1.graph_util.convert_variables_to_constants(
sess=session,
input_graph_def=tfv1.get_default_graph().as_graph_def(),
output_node_names=output_names)
frozen_graph = tfv1.graph_util.extract_sub_graph(
graph_def=frozen_graph,
dest_nodes=output_names)
if not FLAGS.export_tflite:
with open(output_graph_path, 'wb') as fout:
fout.write(frozen_graph.SerializeToString())
else:
output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))
converter = tf.lite.TFLiteConverter(frozen_graph, input_tensors=inputs.values(), output_tensors=outputs.values())
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# AudioSpectrogram and Mfcc ops are custom but have built-in kernels in TFLite
converter.allow_custom_ops = True
tflite_model = converter.convert()
with open(output_tflite_path, 'wb') as fout:
fout.write(tflite_model)
log_info('Models exported at %s' % (FLAGS.export_dir))
metadata_fname = os.path.join(FLAGS.export_dir, '{}_{}_{}.md'.format(
FLAGS.export_author_id,
FLAGS.export_model_name,
FLAGS.export_model_version))
model_runtime = 'tflite' if FLAGS.export_tflite else 'tensorflow'
with open(metadata_fname, 'w') as f:
f.write('---\n')
f.write('author: {}\n'.format(FLAGS.export_author_id))
f.write('model_name: {}\n'.format(FLAGS.export_model_name))
f.write('model_version: {}\n'.format(FLAGS.export_model_version))
f.write('contact_info: {}\n'.format(FLAGS.export_contact_info))
f.write('license: {}\n'.format(FLAGS.export_license))
f.write('language: {}\n'.format(FLAGS.export_language))
f.write('runtime: {}\n'.format(model_runtime))
f.write('min_ds_version: {}\n'.format(FLAGS.export_min_ds_version))
f.write('max_ds_version: {}\n'.format(FLAGS.export_max_ds_version))
f.write('acoustic_model_url: <replace this with a publicly available URL of the acoustic model>\n')
f.write('scorer_url: <replace this with a publicly available URL of the scorer, if present>\n')
f.write('---\n')
f.write('{}\n'.format(FLAGS.export_description))
log_info('Model metadata file saved to {}. Before submitting the exported model for publishing make sure all information in the metadata file is correct, and complete the URL fields.'.format(metadata_fname))
def package_zip():
# --export_dir path/to/export/LANG_CODE/ => path/to/export/LANG_CODE.zip
export_dir = os.path.join(os.path.abspath(FLAGS.export_dir), '') # Force ending '/'
zip_filename = os.path.dirname(export_dir)
shutil.copy(FLAGS.scorer_path, export_dir)
archive = shutil.make_archive(zip_filename, 'zip', export_dir)
log_info('Exported packaged model {}'.format(archive))
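# Editor's note -- illustrative, not part of the original file. The
# trailing-slash trick in package_zip(): joining with '' forces a trailing
# separator, and dirname of such a path strips only that separator, yielding
# a clean base name for shutil.make_archive:
def _zip_name_example():
    export_dir = os.path.join('/tmp/export/eo', '')  # '/tmp/export/eo/'
    return os.path.dirname(export_dir)               # '/tmp/export/eo'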
def do_single_file_inference(input_file_path):
with tfv1.Session(config=Config.session_config) as session:
inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=-1)
# Restore variables from training checkpoint
if FLAGS.load == 'auto':
method_order = ['best', 'last']
else:
method_order = [FLAGS.load]
load_or_init_graph(session, method_order)
features, features_len = audiofile_to_features(input_file_path)
previous_state_c = np.zeros([1, Config.n_cell_dim])
previous_state_h = np.zeros([1, Config.n_cell_dim])
# Add batch dimension
features = tf.expand_dims(features, 0)
features_len = tf.expand_dims(features_len, 0)
# Evaluate
features = create_overlapping_windows(features).eval(session=session)
features_len = features_len.eval(session=session)
logits = outputs['outputs'].eval(feed_dict={
inputs['input']: features,
inputs['input_lengths']: features_len,
inputs['previous_state_c']: previous_state_c,
inputs['previous_state_h']: previous_state_h,
}, session=session)
logits = np.squeeze(logits)
if FLAGS.scorer_path:
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
FLAGS.scorer_path, Config.alphabet)
else:
scorer = None
decoded = ctc_beam_search_decoder(logits, Config.alphabet, FLAGS.beam_width,
scorer=scorer, cutoff_prob=FLAGS.cutoff_prob,
cutoff_top_n=FLAGS.cutoff_top_n)
# Print highest probability result
print(decoded[0][1])
def main(_):
initialize_globals()
if FLAGS.train_files:
tfv1.reset_default_graph()
tfv1.set_random_seed(FLAGS.random_seed)
train()
if FLAGS.test_files:
tfv1.reset_default_graph()
test()
if FLAGS.export_dir and not FLAGS.export_zip:
tfv1.reset_default_graph()
export()
if FLAGS.export_zip:
tfv1.reset_default_graph()
FLAGS.export_tflite = True
if os.listdir(FLAGS.export_dir):
log_error('Directory {} is not empty, please fix this.'.format(FLAGS.export_dir))
sys.exit(1)
export()
package_zip()
if FLAGS.one_shot_infer:
tfv1.reset_default_graph()
do_single_file_inference(FLAGS.one_shot_infer)
if __name__ == '__main__':
-    create_flags()
-    absl.app.run(main)
+    try:
+        from deepspeech_training import train as ds_train
+    except ImportError:
+        print('Training package is not installed. See training documentation.')
+        raise
+    ds_train.run_script()
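The new DeepSpeech.py is thus reduced to a shim that delegates to `deepspeech_training.train.run_script`, whose implementation is not part of this diff. Judging from the entry point deleted above (`create_flags()` followed by `absl.app.run(main)`), a plausible hedged sketch of what it wraps:

# Hypothetical sketch of deepspeech_training/train.py's entry point -- the real
# implementation is not shown in this excerpt.
def run_script():
    create_flags()      # register all absl flags, as the old __main__ block did
    absl.app.run(main)  # parse flags and dispatch to main()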

View file

@@ -150,7 +150,7 @@ COPY . /DeepSpeech/
WORKDIR /DeepSpeech
-RUN pip3 --no-cache-dir install -r requirements.txt
+RUN pip3 --no-cache-dir install .
# Link DeepSpeech native_client libs to tf folder
RUN ln -s /DeepSpeech/native_client /tensorflow

View file

@@ -5,18 +5,12 @@ Use "python3 build_sdb.py -h" for help
'''
from __future__ import absolute_import, division, print_function
-# Make sure we can import stuff from util/
-# This script needs to be run from the root of the DeepSpeech repository
-import os
-import sys
-sys.path.insert(1, os.path.join(sys.path[0], '..'))
import argparse
import progressbar
-from util.downloader import SIMPLE_BAR
-from util.audio import change_audio_types, AUDIO_TYPE_WAV, AUDIO_TYPE_OPUS
-from util.sample_collections import samples_from_files, DirectSDBWriter
+from deepspeech_training.util.downloader import SIMPLE_BAR
+from deepspeech_training.util.audio import change_audio_types, AUDIO_TYPE_WAV, AUDIO_TYPE_OPUS
+from deepspeech_training.util.sample_collections import samples_from_files, DirectSDBWriter
AUDIO_TYPE_LOOKUP = {
    'wav': AUDIO_TYPE_WAV,

View file

@@ -1,11 +1,5 @@
#!/usr/bin/env python
-import sys
-import os
-sys.path.append(os.path.abspath('.'))
-from util.gpu_usage import GPUUsage
+from deepspeech_training.util.gpu_usage import GPUUsage
gu = GPUUsage()
gu.start()

View file

@@ -1,10 +1,6 @@
#!/usr/bin/env python
import sys
-import os
-sys.path.append(os.path.abspath('.'))
-from util.gpu_usage import GPUUsageChart
+from deepspeech_training.util.gpu_usage import GPUUsageChart
GPUUsageChart(sys.argv[1], sys.argv[2])

View file

@@ -1,17 +1,12 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function
-# Make sure we can import stuff from util/
-# This script needs to be run from the root of the DeepSpeech repository
-import os
-import sys
-sys.path.insert(1, os.path.join(sys.path[0], '..'))
-from util.importers import get_importers_parser
import glob
+import os
import pandas
import tarfile
+from deepspeech_training.util.importers import get_importers_parser
COLUMN_NAMES = ['wav_filename', 'wav_filesize', 'transcript']

View file

@@ -1,17 +1,12 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function
-# Make sure we can import stuff from util/
-# This script needs to be run from the root of the DeepSpeech repository
-import os
-import sys
-sys.path.insert(1, os.path.join(sys.path[0], '..'))
-from util.importers import get_importers_parser
import glob
-import tarfile
+import os
import pandas
+import tarfile
+from deepspeech_training.util.importers import get_importers_parser
COLUMNNAMES = ['wav_filename', 'wav_filesize', 'transcript']

View file

@@ -1,23 +1,17 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function
-# Make sure we can import stuff from util/
-# This script needs to be run from the root of the DeepSpeech repository
-import os
-import sys
-sys.path.insert(1, os.path.join(sys.path[0], '..'))
import csv
-import sox
-import tarfile
-import subprocess
+import os
import progressbar
+import sox
+import subprocess
+import tarfile
from glob import glob
-from os import path
from multiprocessing import Pool
-from util.importers import validate_label_eng as validate_label, get_counter, get_imported_samples, print_import_report
-from util.downloader import maybe_download, SIMPLE_BAR
+from deepspeech_training.util.importers import validate_label_eng as validate_label, get_counter, get_imported_samples, print_import_report
+from deepspeech_training.util.downloader import maybe_download, SIMPLE_BAR
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
SAMPLE_RATE = 16000
@@ -28,7 +22,7 @@ ARCHIVE_URL = 'https://s3.us-east-2.amazonaws.com/common-voice-data-download/' +
def _download_and_preprocess_data(target_dir):
    # Making path absolute
-    target_dir = path.abspath(target_dir)
+    target_dir = os.path.abspath(target_dir)
    # Conditionally download data
    archive_path = maybe_download(ARCHIVE_NAME, target_dir, ARCHIVE_URL)
    # Conditionally extract common voice data
@@ -38,8 +32,8 @@ def _download_and_preprocess_data(target_dir):
def _maybe_extract(target_dir, extracted_data, archive_path):
    # If target_dir/extracted_data does not exist, extract archive in target_dir
-    extracted_path = path.join(target_dir, extracted_data)
+    extracted_path = os.path.join(target_dir, extracted_data)
-    if not path.exists(extracted_path):
+    if not os.path.exists(extracted_path):
        print('No directory "%s" - extracting archive...' % extracted_path)
        with tarfile.open(archive_path) as tar:
            tar.extractall(target_dir)
@@ -47,9 +41,9 @@ def _maybe_extract(target_dir, extracted_data, archive_path):
        print('Found directory "%s" - not extracting it from archive.' % extracted_path)
def _maybe_convert_sets(target_dir, extracted_data):
-    extracted_dir = path.join(target_dir, extracted_data)
+    extracted_dir = os.path.join(target_dir, extracted_data)
-    for source_csv in glob(path.join(extracted_dir, '*.csv')):
+    for source_csv in glob(os.path.join(extracted_dir, '*.csv')):
-        _maybe_convert_set(extracted_dir, source_csv, path.join(target_dir, os.path.split(source_csv)[-1]))
+        _maybe_convert_set(extracted_dir, source_csv, os.path.join(target_dir, os.path.split(source_csv)[-1]))
def one_sample(sample):
    mp3_filename = sample[0]
@@ -58,7 +52,7 @@ def one_sample(sample):
    _maybe_convert_wav(mp3_filename, wav_filename)
    frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
    file_size = -1
-    if path.exists(wav_filename):
+    if os.path.exists(wav_filename):
-        file_size = path.getsize(wav_filename)
+        file_size = os.path.getsize(wav_filename)
        frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
    label = validate_label(sample[1])
@@ -85,7 +79,7 @@ def one_sample(sample):
def _maybe_convert_set(extracted_dir, source_csv, target_csv):
    print()
-    if path.exists(target_csv):
+    if os.path.exists(target_csv):
        print('Found CSV file "%s" - not importing "%s".' % (target_csv, source_csv))
        return
    print('No CSV file "%s" - importing "%s"...' % (target_csv, source_csv))
@@ -126,7 +120,7 @@ def _maybe_convert_set(extracted_dir, source_csv, target_csv):
    print_import_report(counter, SAMPLE_RATE, MAX_SECS)
def _maybe_convert_wav(mp3_filename, wav_filename):
-    if not path.exists(wav_filename):
+    if not os.path.exists(wav_filename):
        transformer = sox.Transformer()
        transformer.convert(samplerate=SAMPLE_RATE)
        try:

View file

@@ -8,23 +8,17 @@ Use "python3 import_cv2.py -h" for help
'''
from __future__ import absolute_import, division, print_function
-# Make sure we can import stuff from util/
-# This script needs to be run from the root of the DeepSpeech repository
-import os
-import sys
-sys.path.insert(1, os.path.join(sys.path[0], '..'))
import csv
+import os
+import progressbar
import sox
import subprocess
-import progressbar
import unicodedata
-from os import path
from multiprocessing import Pool
-from util.downloader import SIMPLE_BAR
-from util.text import Alphabet
-from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
+from deepspeech_training.util.downloader import SIMPLE_BAR
+from deepspeech_training.util.text import Alphabet
+from deepspeech_training.util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
@@ -34,7 +28,7 @@ MAX_SECS = 10
def _preprocess_data(tsv_dir, audio_dir, space_after_every_character=False):
    for dataset in ['train', 'test', 'dev', 'validated', 'other']:
-        input_tsv = path.join(path.abspath(tsv_dir), dataset+".tsv")
+        input_tsv = os.path.join(os.path.abspath(tsv_dir), dataset+".tsv")
        if os.path.isfile(input_tsv):
            print("Loading TSV file: ", input_tsv)
            _maybe_convert_set(input_tsv, audio_dir, space_after_every_character)
@@ -42,15 +36,15 @@ def _preprocess_data(tsv_dir, audio_dir, space_after_every_character=False):
def one_sample(sample):
    """ Take an audio file, and optionally convert it to 16kHz WAV """
    mp3_filename = sample[0]
-    if not path.splitext(mp3_filename.lower())[1] == '.mp3':
+    if not os.path.splitext(mp3_filename.lower())[1] == '.mp3':
        mp3_filename += ".mp3"
    # Storing wav files next to the mp3 ones - just with a different suffix
-    wav_filename = path.splitext(mp3_filename)[0] + ".wav"
+    wav_filename = os.path.splitext(mp3_filename)[0] + ".wav"
    _maybe_convert_wav(mp3_filename, wav_filename)
    file_size = -1
    frames = 0
-    if path.exists(wav_filename):
+    if os.path.exists(wav_filename):
-        file_size = path.getsize(wav_filename)
+        file_size = os.path.getsize(wav_filename)
        frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
    label = label_filter_fun(sample[1])
    rows = []
@@ -76,7 +70,7 @@ def one_sample(sample):
    return (counter, rows)
def _maybe_convert_set(input_tsv, audio_dir, space_after_every_character=None):
-    output_csv = path.join(audio_dir, os.path.split(input_tsv)[-1].replace('tsv', 'csv'))
+    output_csv = os.path.join(audio_dir, os.path.split(input_tsv)[-1].replace('tsv', 'csv'))
    print("Saving new DeepSpeech-formatted CSV file to: ", output_csv)
@@ -84,7 +78,7 @@ def _maybe_convert_set(input_tsv, audio_dir, space_after_every_character=None):
    with open(input_tsv, encoding='utf-8') as input_tsv_file:
        reader = csv.DictReader(input_tsv_file, delimiter='\t')
        for row in reader:
-            samples.append((path.join(audio_dir, row['path']), row['sentence']))
+            samples.append((os.path.join(audio_dir, row['path']), row['sentence']))
    counter = get_counter()
    num_samples = len(samples)
@@ -120,7 +114,7 @@ def _maybe_convert_set(input_tsv, audio_dir, space_after_every_character=None):
def _maybe_convert_wav(mp3_filename, wav_filename):
-    if not path.exists(wav_filename):
+    if not os.path.exists(wav_filename):
        transformer = sox.Transformer()
        transformer.convert(samplerate=SAMPLE_RATE)
        try:

View file

@@ -4,22 +4,17 @@ from __future__ import absolute_import, division, print_function
# Prerequisite: Having the sph2pipe tool in your PATH:
# https://www.ldc.upenn.edu/language-resources/tools/sphere-conversion-tools
-# Make sure we can import stuff from util/
-# This script needs to be run from the root of the DeepSpeech repository
-import sys
-import os
-sys.path.insert(1, os.path.join(sys.path[0], '..'))
import codecs
import fnmatch
+import librosa
import os
import pandas
-import subprocess
-import unicodedata
-import librosa
import soundfile # <= Has an external dependency on libsndfile
+import subprocess
+import sys
+import unicodedata
-from util.importers import validate_label_eng as validate_label
+from deepspeech_training.util.importers import validate_label_eng as validate_label
def _download_and_preprocess_data(data_dir):
    # Assume data_dir contains extracted LDC2004S13, LDC2004T19, LDC2005S13, LDC2005T19

View file

@@ -1,18 +1,13 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function
-# Make sure we can import stuff from util/
-# This script needs to be run from the root of the DeepSpeech repository
-import os
-import sys
-sys.path.insert(1, os.path.join(sys.path[0], '..'))
-from util.importers import get_importers_parser
import glob
import numpy as np
+import os
import pandas
import tarfile
+from deepspeech_training.util.importers import get_importers_parser
COLUMN_NAMES = ['wav_filename', 'wav_filesize', 'transcript']

View file

@@ -1,22 +1,16 @@
#!/usr/bin/env python
-# Make sure we can import stuff from util/
-# This script needs to be run from the root of the DeepSpeech repository
-import os
-import sys
-sys.path.insert(1, os.path.join(sys.path[0], '..'))
import csv
-import math
-import urllib
import logging
-from util.importers import get_importers_parser, get_validate_label
-import subprocess
-from os import path
-from pathlib import Path
-import swifter
+import math
+import os
import pandas as pd
+import swifter
+import subprocess
+import urllib
+from deepspeech_training.util.importers import get_importers_parser, get_validate_label
+from pathlib import Path
from sox import Transformer
@@ -142,11 +136,11 @@ class GramVaaniDownloader:
        return mp3_directory
    def _pre_download(self):
-        mp3_directory = path.join(self.target_dir, "mp3")
+        mp3_directory = os.path.join(self.target_dir, "mp3")
-        if not path.exists(self.target_dir):
+        if not os.path.exists(self.target_dir):
            _logger.info("Creating directory...%s", self.target_dir)
            os.mkdir(self.target_dir)
-        if not path.exists(mp3_directory):
+        if not os.path.exists(mp3_directory):
            _logger.info("Creating directory...%s", mp3_directory)
            os.mkdir(mp3_directory)
        return mp3_directory
@@ -154,8 +148,8 @@ class GramVaaniDownloader:
    def _download(self, audio_url, transcript, audio_length, mp3_directory):
        if audio_url == "audio_url":
            return
-        mp3_filename = path.join(mp3_directory, os.path.basename(audio_url))
+        mp3_filename = os.path.join(mp3_directory, os.path.basename(audio_url))
-        if not path.exists(mp3_filename):
+        if not os.path.exists(mp3_filename):
            _logger.debug("Downloading mp3 file...%s", audio_url)
            urllib.request.urlretrieve(audio_url, mp3_filename)
        else:
@@ -182,8 +176,8 @@ class GramVaaniConverter:
        """
        wav_directory = self._pre_convert()
        for mp3_filename in self.mp3_directory.glob('**/*.mp3'):
-            wav_filename = path.join(wav_directory, os.path.splitext(os.path.basename(mp3_filename))[0] + ".wav")
+            wav_filename = os.path.join(wav_directory, os.path.splitext(os.path.basename(mp3_filename))[0] + ".wav")
-            if not path.exists(wav_filename):
+            if not os.path.exists(wav_filename):
                _logger.debug("Converting mp3 file %s to wav file %s" % (mp3_filename, wav_filename))
                transformer = Transformer()
                transformer.convert(samplerate=SAMPLE_RATE, n_channels=N_CHANNELS, bitdepth=BITDEPTH)
@@ -193,11 +187,11 @@ class GramVaaniConverter:
        return wav_directory
    def _pre_convert(self):
-        wav_directory = path.join(self.target_dir, "wav")
+        wav_directory = os.path.join(self.target_dir, "wav")
-        if not path.exists(self.target_dir):
+        if not os.path.exists(self.target_dir):
            _logger.info("Creating directory...%s", self.target_dir)
            os.mkdir(self.target_dir)
-        if not path.exists(wav_directory):
+        if not os.path.exists(wav_directory):
            _logger.info("Creating directory...%s", wav_directory)
            os.mkdir(wav_directory)
        return wav_directory
@@ -233,8 +227,8 @@ class GramVaaniDataSets:
        if audio_url == "audio_url":
            return pd.Series(["wav_filename", "wav_filesize", "transcript"])
        mp3_filename = os.path.basename(audio_url)
-        wav_relative_filename = path.join("wav", os.path.splitext(os.path.basename(mp3_filename))[0] + ".wav")
+        wav_relative_filename = os.path.join("wav", os.path.splitext(os.path.basename(mp3_filename))[0] + ".wav")
-        wav_filesize = path.getsize(path.join(self.target_dir, wav_relative_filename))
+        wav_filesize = os.path.getsize(os.path.join(self.target_dir, wav_relative_filename))
        transcript = validate_label(transcript)
        if None == transcript:
            transcript = ""
@@ -252,7 +246,7 @@ class GramVaaniDataSets:
    def _is_valid_raw_wav_frames(self):
        transcripts = [str(transcript) for transcript in self.raw.transcript]
-        wav_filepaths = [path.join(self.target_dir, str(wav_filename)) for wav_filename in self.raw.wav_filename]
+        wav_filepaths = [os.path.join(self.target_dir, str(wav_filename)) for wav_filename in self.raw.wav_filename]
        wav_frames = [int(subprocess.check_output(['soxi', '-s', wav_filepath], stderr=subprocess.STDOUT)) for wav_filepath in wav_filepaths]
        is_valid_raw_wav_frames = [self._is_wav_frame_valid(wav_frame, transcript) for wav_frame, transcript in zip(wav_frames, transcripts)]
        return pd.Series(is_valid_raw_wav_frames)

View file

@@ -1,15 +1,11 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function
-# Make sure we can import stuff from util/
-# This script needs to be run from the root of the DeepSpeech repository
-import sys
-import os
-sys.path.insert(1, os.path.join(sys.path[0], '..'))
import pandas
+import os
+import sys
-from util.downloader import maybe_download
+from deepspeech_training.util.downloader import maybe_download
def _download_and_preprocess_data(data_dir):
    # Conditionally download data

View file

@ -1,22 +1,18 @@
#!/usr/bin/env python #!/usr/bin/env python
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import sys
import os
sys.path.insert(1, os.path.join(sys.path[0], '..'))
import codecs import codecs
import fnmatch import fnmatch
import os
import pandas import pandas
import progressbar import progressbar
import subprocess import subprocess
import sys
import tarfile import tarfile
import unicodedata import unicodedata
from deepspeech_training.util.downloader import maybe_download
from sox import Transformer from sox import Transformer
from util.downloader import maybe_download
from tensorflow.python.platform import gfile from tensorflow.python.platform import gfile
SAMPLE_RATE = 16000 SAMPLE_RATE = 16000
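The importers above now take maybe_download from the installed deepspeech_training package instead of a repo-relative util/ module. A rough sketch of what such a helper does, inferred only from the maybe_download(archive_name, target_dir, archive_url) calls visible in these diffs; the real implementation also reports download progress:

    import os
    import requests

    def maybe_download(archive_name, target_dir, archive_url):
        # Download archive_url into target_dir unless the file already exists
        archive_path = os.path.join(target_dir, archive_name)
        if not os.path.exists(archive_path):
            os.makedirs(target_dir, exist_ok=True)
            response = requests.get(archive_url, stream=True)
            with open(archive_path, 'wb') as archive_file:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    archive_file.write(chunk)
        return archive_path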


@@ -1,31 +1,23 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
import argparse import argparse
import csv import csv
import os
import progressbar
import re import re
import sox import sox
import zipfile
import subprocess import subprocess
import progressbar
import unicodedata import unicodedata
import zipfile
from multiprocessing import Pool from deepspeech_training.util.downloader import maybe_download
from util.downloader import SIMPLE_BAR from deepspeech_training.util.downloader import SIMPLE_BAR
from deepspeech_training.util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
from os import path from deepspeech_training.util.text import Alphabet
from glob import glob from glob import glob
from multiprocessing import Pool
from util.downloader import maybe_download
from util.text import Alphabet
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript'] FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
SAMPLE_RATE = 16000 SAMPLE_RATE = 16000
@@ -38,7 +30,7 @@ ARCHIVE_URL = 'https://lingualibre.fr/datasets/' + ARCHIVE_NAME
def _download_and_preprocess_data(target_dir): def _download_and_preprocess_data(target_dir):
# Making path absolute # Making path absolute
target_dir = path.abspath(target_dir) target_dir = os.path.abspath(target_dir)
# Conditionally download data # Conditionally download data
archive_path = maybe_download(ARCHIVE_NAME, target_dir, ARCHIVE_URL) archive_path = maybe_download(ARCHIVE_NAME, target_dir, ARCHIVE_URL)
# Conditionally extract data # Conditionally extract data
@@ -48,8 +40,8 @@ def _download_and_preprocess_data(target_dir):
def _maybe_extract(target_dir, extracted_data, archive_path): def _maybe_extract(target_dir, extracted_data, archive_path):
# If target_dir/extracted_data does not exist, extract archive in target_dir # If target_dir/extracted_data does not exist, extract archive in target_dir
extracted_path = path.join(target_dir, extracted_data) extracted_path = os.path.join(target_dir, extracted_data)
if not path.exists(extracted_path): if not os.path.exists(extracted_path):
print('No directory "%s" - extracting archive...' % extracted_path) print('No directory "%s" - extracting archive...' % extracted_path)
if not os.path.isdir(extracted_path): if not os.path.isdir(extracted_path):
os.mkdir(extracted_path) os.mkdir(extracted_path)
@@ -62,12 +54,12 @@ def one_sample(sample):
""" Take a audio file, and optionally convert it to 16kHz WAV """ """ Take a audio file, and optionally convert it to 16kHz WAV """
ogg_filename = sample[0] ogg_filename = sample[0]
# Storing wav files next to the ogg ones - just with a different suffix # Storing wav files next to the ogg ones - just with a different suffix
wav_filename = path.splitext(ogg_filename)[0] + ".wav" wav_filename = os.path.splitext(ogg_filename)[0] + ".wav"
_maybe_convert_wav(ogg_filename, wav_filename) _maybe_convert_wav(ogg_filename, wav_filename)
file_size = -1 file_size = -1
frames = 0 frames = 0
if path.exists(wav_filename): if os.path.exists(wav_filename):
file_size = path.getsize(wav_filename) file_size = os.path.getsize(wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT)) frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
label = label_filter(sample[1]) label = label_filter(sample[1])
rows = [] rows = []
@@ -94,7 +86,7 @@ def one_sample(sample):
return (counter, rows) return (counter, rows)
def _maybe_convert_sets(target_dir, extracted_data): def _maybe_convert_sets(target_dir, extracted_data):
extracted_dir = path.join(target_dir, extracted_data) extracted_dir = os.path.join(target_dir, extracted_data)
# override existing CSV with normalized one # override existing CSV with normalized one
target_csv_template = os.path.join(target_dir, ARCHIVE_DIR_NAME + '_' + ARCHIVE_NAME.replace('.zip', '_{}.csv')) target_csv_template = os.path.join(target_dir, ARCHIVE_DIR_NAME + '_' + ARCHIVE_NAME.replace('.zip', '_{}.csv'))
if os.path.isfile(target_csv_template): if os.path.isfile(target_csv_template):
@@ -160,7 +152,7 @@ def _maybe_convert_sets(target_dir, extracted_data):
print_import_report(counter, SAMPLE_RATE, MAX_SECS) print_import_report(counter, SAMPLE_RATE, MAX_SECS)
def _maybe_convert_wav(ogg_filename, wav_filename): def _maybe_convert_wav(ogg_filename, wav_filename):
if not path.exists(wav_filename): if not os.path.exists(wav_filename):
transformer = sox.Transformer() transformer = sox.Transformer()
transformer.convert(samplerate=SAMPLE_RATE) transformer.convert(samplerate=SAMPLE_RATE)
try: try:

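The _maybe_convert_wav helpers in these importers resample source audio to 16 kHz WAV with pysox. A minimal sketch of that conversion step under standard pysox semantics:

    import sox

    def convert_to_16k_wav(src_filename, wav_filename):
        # Resample to 16 kHz and write the WAV next to the source file
        transformer = sox.Transformer()
        transformer.convert(samplerate=16000)
        transformer.build(src_filename, wav_filename)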

@@ -2,29 +2,19 @@
# pylint: disable=invalid-name # pylint: disable=invalid-name
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
import csv import csv
import subprocess import os
import progressbar import progressbar
import unicodedata import subprocess
import tarfile import tarfile
import unicodedata
from multiprocessing import Pool from deepspeech_training.util.downloader import maybe_download
from util.downloader import SIMPLE_BAR from deepspeech_training.util.downloader import SIMPLE_BAR
from deepspeech_training.util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
from os import path from deepspeech_training.util.text import Alphabet
from glob import glob from glob import glob
from multiprocessing import Pool
from util.downloader import maybe_download
from util.text import Alphabet
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript'] FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
SAMPLE_RATE = 16000 SAMPLE_RATE = 16000
@@ -37,7 +27,7 @@ ARCHIVE_URL = 'http://www.caito.de/data/Training/stt_tts/' + ARCHIVE_NAME
def _download_and_preprocess_data(target_dir): def _download_and_preprocess_data(target_dir):
# Making path absolute # Making path absolute
target_dir = path.abspath(target_dir) target_dir = os.path.abspath(target_dir)
# Conditionally download data # Conditionally download data
archive_path = maybe_download(ARCHIVE_NAME, target_dir, ARCHIVE_URL) archive_path = maybe_download(ARCHIVE_NAME, target_dir, ARCHIVE_URL)
# Conditionally extract data # Conditionally extract data
@@ -48,8 +38,8 @@ def _download_and_preprocess_data(target_dir):
def _maybe_extract(target_dir, extracted_data, archive_path): def _maybe_extract(target_dir, extracted_data, archive_path):
# If target_dir/extracted_data does not exist, extract archive in target_dir # If target_dir/extracted_data does not exist, extract archive in target_dir
extracted_path = path.join(target_dir, extracted_data) extracted_path = os.path.join(target_dir, extracted_data)
if not path.exists(extracted_path): if not os.path.exists(extracted_path):
print('No directory "%s" - extracting archive...' % extracted_path) print('No directory "%s" - extracting archive...' % extracted_path)
if not os.path.isdir(extracted_path): if not os.path.isdir(extracted_path):
os.mkdir(extracted_path) os.mkdir(extracted_path)
@@ -65,8 +55,8 @@ def one_sample(sample):
wav_filename = sample[0] wav_filename = sample[0]
file_size = -1 file_size = -1
frames = 0 frames = 0
if path.exists(wav_filename): if os.path.exists(wav_filename):
file_size = path.getsize(wav_filename) file_size = os.path.getsize(wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT)) frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
label = label_filter(sample[1]) label = label_filter(sample[1])
counter = get_counter() counter = get_counter()
@@ -93,7 +83,7 @@ def one_sample(sample):
return (counter, rows) return (counter, rows)
def _maybe_convert_sets(target_dir, extracted_data): def _maybe_convert_sets(target_dir, extracted_data):
extracted_dir = path.join(target_dir, extracted_data) extracted_dir = os.path.join(target_dir, extracted_data)
# override existing CSV with normalized one # override existing CSV with normalized one
target_csv_template = os.path.join(target_dir, ARCHIVE_DIR_NAME, ARCHIVE_NAME.replace('.tgz', '_{}.csv')) target_csv_template = os.path.join(target_dir, ARCHIVE_DIR_NAME, ARCHIVE_NAME.replace('.tgz', '_{}.csv'))
if os.path.isfile(target_csv_template): if os.path.isfile(target_csv_template):


@@ -1,18 +1,13 @@
#!/usr/bin/env python #!/usr/bin/env python
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
from util.importers import get_importers_parser
import glob import glob
import os
import pandas import pandas
import tarfile import tarfile
import wave import wave
from deepspeech_training.util.importers import get_importers_parser
COLUMN_NAMES = ['wav_filename', 'wav_filesize', 'transcript'] COLUMN_NAMES = ['wav_filename', 'wav_filesize', 'transcript']


@@ -1,19 +1,14 @@
#!/usr/bin/env python #!/usr/bin/env python
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
from util.importers import get_importers_parser
import glob import glob
import json import json
import numpy as np import numpy as np
import os
import pandas import pandas
import tarfile import tarfile
from deepspeech_training.util.importers import get_importers_parser
COLUMN_NAMES = ['wav_filename', 'wav_filesize', 'transcript'] COLUMN_NAMES = ['wav_filename', 'wav_filesize', 'transcript']


@@ -1,32 +1,22 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
import csv import csv
import os
import progressbar
import re import re
import sox import sox
import zipfile
import subprocess import subprocess
import progressbar
import unicodedata
import tarfile import tarfile
import unicodedata
import zipfile
from multiprocessing import Pool from deepspeech_training.util.downloader import maybe_download, SIMPLE_BAR
from util.downloader import SIMPLE_BAR from deepspeech_training.util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
from deepspeech_training.util.text import Alphabet
from os import path
from glob import glob from glob import glob
from multiprocessing import Pool
from util.downloader import maybe_download
from util.text import Alphabet
from util.helpers import secs_to_hours
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript'] FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
SAMPLE_RATE = 16000 SAMPLE_RATE = 16000
@@ -39,7 +29,7 @@ ARCHIVE_URL = 'http://www.openslr.org/resources/57/' + ARCHIVE_NAME
def _download_and_preprocess_data(target_dir): def _download_and_preprocess_data(target_dir):
# Making path absolute # Making path absolute
target_dir = path.abspath(target_dir) target_dir = os.path.abspath(target_dir)
# Conditionally download data # Conditionally download data
archive_path = maybe_download(ARCHIVE_NAME, target_dir, ARCHIVE_URL) archive_path = maybe_download(ARCHIVE_NAME, target_dir, ARCHIVE_URL)
# Conditionally extract data # Conditionally extract data
@@ -49,8 +39,8 @@ def _download_and_preprocess_data(target_dir):
def _maybe_extract(target_dir, extracted_data, archive_path): def _maybe_extract(target_dir, extracted_data, archive_path):
# If target_dir/extracted_data does not exist, extract archive in target_dir # If target_dir/extracted_data does not exist, extract archive in target_dir
extracted_path = path.join(target_dir, extracted_data) extracted_path = os.path.join(target_dir, extracted_data)
if not path.exists(extracted_path): if not os.path.exists(extracted_path):
print('No directory "%s" - extracting archive...' % extracted_path) print('No directory "%s" - extracting archive...' % extracted_path)
if not os.path.isdir(extracted_path): if not os.path.isdir(extracted_path):
os.mkdir(extracted_path) os.mkdir(extracted_path)
@@ -65,8 +55,8 @@ def one_sample(sample):
wav_filename = sample[0] wav_filename = sample[0]
file_size = -1 file_size = -1
frames = 0 frames = 0
if path.exists(wav_filename): if os.path.exists(wav_filename):
file_size = path.getsize(wav_filename) file_size = os.path.getsize(wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT)) frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
label = label_filter(sample[1]) label = label_filter(sample[1])
counter = get_counter() counter = get_counter()
@@ -92,7 +82,7 @@ def one_sample(sample):
return (counter, rows) return (counter, rows)
def _maybe_convert_sets(target_dir, extracted_data): def _maybe_convert_sets(target_dir, extracted_data):
extracted_dir = path.join(target_dir, extracted_data) extracted_dir = os.path.join(target_dir, extracted_data)
# override existing CSV with normalized one # override existing CSV with normalized one
target_csv_template = os.path.join(target_dir, ARCHIVE_DIR_NAME, ARCHIVE_NAME.replace('.tar.gz', '_{}.csv')) target_csv_template = os.path.join(target_dir, ARCHIVE_DIR_NAME, ARCHIVE_NAME.replace('.tar.gz', '_{}.csv'))
if os.path.isfile(target_csv_template): if os.path.isfile(target_csv_template):


@@ -1,28 +1,24 @@
#!/usr/bin/env python #!/usr/bin/env python
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
# ensure that you have downloaded the LDC dataset LDC97S62 and tar exists in a folder e.g. # ensure that you have downloaded the LDC dataset LDC97S62 and tar exists in a folder e.g.
# ./data/swb/swb1_LDC97S62.tgz # ./data/swb/swb1_LDC97S62.tgz
# from the deepspeech directory run with: ./bin/import_swb.py ./data/swb/ # from the deepspeech directory run with: ./bin/import_swb.py ./data/swb/
import codecs
import sys
import os
sys.path.insert(1, os.path.join(sys.path[0], '..'))
import fnmatch import fnmatch
import librosa
import os
import pandas import pandas
import requests
import soundfile # <= Has an external dependency on libsndfile
import subprocess import subprocess
import sys
import tarfile
import unicodedata import unicodedata
import wave import wave
import codecs
import tarfile from deepspeech_training.util.importers import validate_label_eng as validate_label
import requests
from util.importers import validate_label_eng as validate_label
import librosa
import soundfile # <= Has an external dependency on libsndfile
# ARCHIVE_NAME refers to ISIP alignments from 01/29/03 # ARCHIVE_NAME refers to ISIP alignments from 01/29/03
ARCHIVE_NAME = 'switchboard_word_alignments.tar.gz' ARCHIVE_NAME = 'switchboard_word_alignments.tar.gz'


@@ -5,31 +5,26 @@ Use "python3 import_swc.py -h" for help
''' '''
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
import re
import csv
import sox
import wave
import shutil
import random
import tarfile
import argparse import argparse
import csv
import os
import progressbar import progressbar
import random
import re
import shutil
import sox
import sys
import tarfile
import unicodedata import unicodedata
import wave
import xml.etree.cElementTree as ET import xml.etree.cElementTree as ET
from os import path
from glob import glob from glob import glob
from collections import Counter from collections import Counter
from multiprocessing.pool import ThreadPool from multiprocessing.pool import ThreadPool
from util.text import Alphabet from deepspeech_training.util.text import Alphabet
from util.importers import validate_label_eng as validate_label from deepspeech_training.util.importers import validate_label_eng as validate_label
from util.downloader import maybe_download, SIMPLE_BAR from deepspeech_training.util.downloader import maybe_download, SIMPLE_BAR
SWC_URL = "https://www2.informatik.uni-hamburg.de/nats/pub/SWC/SWC_{language}.tar" SWC_URL = "https://www2.informatik.uni-hamburg.de/nats/pub/SWC/SWC_{language}.tar"
SWC_ARCHIVE = "SWC_{language}.tar" SWC_ARCHIVE = "SWC_{language}.tar"
@@ -117,8 +112,8 @@ def maybe_download_language(language):
def maybe_extract(data_dir, extracted_data, archive): def maybe_extract(data_dir, extracted_data, archive):
extracted = path.join(data_dir, extracted_data) extracted = os.path.join(data_dir, extracted_data)
if path.isdir(extracted): if os.path.isdir(extracted):
print('Found directory "{}" - not extracting.'.format(extracted)) print('Found directory "{}" - not extracting.'.format(extracted))
else: else:
print('Extracting "{}"...'.format(archive)) print('Extracting "{}"...'.format(archive))
@@ -242,7 +237,7 @@ def collect_samples(base_dir, language):
print('Collecting samples...') print('Collecting samples...')
bar = progressbar.ProgressBar(max_value=len(roots), widgets=SIMPLE_BAR) bar = progressbar.ProgressBar(max_value=len(roots), widgets=SIMPLE_BAR)
for root in bar(roots): for root in bar(roots):
wav_path = path.join(root, WAV_NAME) wav_path = os.path.join(root, WAV_NAME)
aligned = ET.parse(path.join(root, ALIGNED_NAME)) aligned = ET.parse(os.path.join(root, ALIGNED_NAME))
article = UNKNOWN article = UNKNOWN
speaker = UNKNOWN speaker = UNKNOWN
@@ -294,8 +289,8 @@ def maybe_convert_one_to_wav(entry):
transformer.convert(samplerate=SAMPLE_RATE, n_channels=CHANNELS) transformer.convert(samplerate=SAMPLE_RATE, n_channels=CHANNELS)
combiner = sox.Combiner() combiner = sox.Combiner()
combiner.convert(samplerate=SAMPLE_RATE, n_channels=CHANNELS) combiner.convert(samplerate=SAMPLE_RATE, n_channels=CHANNELS)
output_wav = path.join(root, WAV_NAME) output_wav = os.path.join(root, WAV_NAME)
if path.isfile(output_wav): if os.path.isfile(output_wav):
return return
files = sorted(glob(path.join(root, AUDIO_PATTERN))) files = sorted(glob(os.path.join(root, AUDIO_PATTERN)))
try: try:
@@ -304,7 +299,7 @@ def maybe_convert_one_to_wav(entry):
elif len(files) > 1: elif len(files) > 1:
wav_files = [] wav_files = []
for i, file in enumerate(files): for i, file in enumerate(files):
wav_path = path.join(root, 'audio{}.wav'.format(i)) wav_path = os.path.join(root, 'audio{}.wav'.format(i))
transformer.build(file, wav_path) transformer.build(file, wav_path)
wav_files.append(wav_path) wav_files.append(wav_path)
combiner.set_input_format(file_type=['wav'] * len(wav_files)) combiner.set_input_format(file_type=['wav'] * len(wav_files))
@@ -358,8 +353,8 @@ def assign_sub_sets(samples):
def create_sample_dirs(language): def create_sample_dirs(language):
print('Creating sample directories...') print('Creating sample directories...')
for set_name in ['train', 'dev', 'test']: for set_name in ['train', 'dev', 'test']:
dir_path = path.join(CLI_ARGS.base_dir, language + '-' + set_name) dir_path = os.path.join(CLI_ARGS.base_dir, language + '-' + set_name)
if not path.isdir(dir_path): if not os.path.isdir(dir_path):
os.mkdir(dir_path) os.mkdir(dir_path)
@@ -374,7 +369,7 @@ def split_audio_files(samples, language):
rate = src_wav_file.getframerate() rate = src_wav_file.getframerate()
for sample in file_samples: for sample in file_samples:
index = sub_sets[sample.sub_set] index = sub_sets[sample.sub_set]
sample_wav_path = path.join(CLI_ARGS.base_dir, sample_wav_path = os.path.join(CLI_ARGS.base_dir,
language + '-' + sample.sub_set, language + '-' + sample.sub_set,
'sample-{0:06d}.wav'.format(index)) 'sample-{0:06d}.wav'.format(index))
sample.wav_path = sample_wav_path sample.wav_path = sample_wav_path
@@ -391,8 +386,8 @@ def split_audio_files(samples, language):
def write_csvs(samples, language): def write_csvs(samples, language):
for sub_set, set_samples in group(samples, lambda s: s.sub_set).items(): for sub_set, set_samples in group(samples, lambda s: s.sub_set).items():
set_samples = sorted(set_samples, key=lambda s: s.wav_path) set_samples = sorted(set_samples, key=lambda s: s.wav_path)
base_dir = path.abspath(CLI_ARGS.base_dir) base_dir = os.path.abspath(CLI_ARGS.base_dir)
csv_path = path.join(base_dir, language + '-' + sub_set + '.csv') csv_path = os.path.join(base_dir, language + '-' + sub_set + '.csv')
print('Writing "{}"...'.format(csv_path)) print('Writing "{}"...'.format(csv_path))
with open(csv_path, 'w') as csv_file: with open(csv_path, 'w') as csv_file:
writer = csv.DictWriter(csv_file, fieldnames=FIELDNAMES_EXT if CLI_ARGS.add_meta else FIELDNAMES) writer = csv.DictWriter(csv_file, fieldnames=FIELDNAMES_EXT if CLI_ARGS.add_meta else FIELDNAMES)
@@ -400,8 +395,8 @@ def write_csvs(samples, language):
bar = progressbar.ProgressBar(max_value=len(set_samples), widgets=SIMPLE_BAR) bar = progressbar.ProgressBar(max_value=len(set_samples), widgets=SIMPLE_BAR)
for sample in bar(set_samples): for sample in bar(set_samples):
row = { row = {
'wav_filename': path.relpath(sample.wav_path, base_dir), 'wav_filename': os.path.relpath(sample.wav_path, base_dir),
'wav_filesize': path.getsize(sample.wav_path), 'wav_filesize': os.path.getsize(sample.wav_path),
'transcript': sample.text 'transcript': sample.text
} }
if CLI_ARGS.add_meta: if CLI_ARGS.add_meta:
@@ -414,8 +409,8 @@ def cleanup(archive, language):
if not CLI_ARGS.keep_archive: if not CLI_ARGS.keep_archive:
print('Removing archive "{}"...'.format(archive)) print('Removing archive "{}"...'.format(archive))
os.remove(archive) os.remove(archive)
language_dir = path.join(CLI_ARGS.base_dir, language) language_dir = os.path.join(CLI_ARGS.base_dir, language)
if not CLI_ARGS.keep_intermediate and path.isdir(language_dir): if not CLI_ARGS.keep_intermediate and os.path.isdir(language_dir):
print('Removing intermediate files in "{}"...'.format(language_dir)) print('Removing intermediate files in "{}"...'.format(language_dir))
shutil.rmtree(language_dir) shutil.rmtree(language_dir)


@@ -1,14 +1,8 @@
#!/usr/bin/env python #!/usr/bin/env python
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import sys
import os
sys.path.insert(1, os.path.join(sys.path[0], '..'))
import codecs
import pandas import pandas
import sys
import tarfile import tarfile
import unicodedata import unicodedata
import wave import wave
@@ -16,9 +10,10 @@ import wave
from glob import glob from glob import glob
from os import makedirs, path, remove, rmdir from os import makedirs, path, remove, rmdir
from sox import Transformer from sox import Transformer
from util.downloader import maybe_download from deepspeech_training.util.downloader import maybe_download
from tensorflow.python.platform import gfile from tensorflow.python.platform import gfile
from util.stm import parse_stm_file from deepspeech_training.util.stm import parse_stm_file
def _download_and_preprocess_data(data_dir): def _download_and_preprocess_data(data_dir):
# Conditionally download data # Conditionally download data


@@ -1,28 +1,20 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import re
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
import csv import csv
import unidecode import os
import zipfile import progressbar
import re
import sox import sox
import subprocess import subprocess
import progressbar import unidecode
import zipfile
from deepspeech_training.util.downloader import maybe_download
from deepspeech_training.util.downloader import SIMPLE_BAR
from deepspeech_training.util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
from multiprocessing import Pool from multiprocessing import Pool
from util.downloader import SIMPLE_BAR
from os import path
from util.downloader import maybe_download
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript'] FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
SAMPLE_RATE = 16000 SAMPLE_RATE = 16000
@@ -34,7 +26,7 @@ ARCHIVE_URL = 'https://deepspeech-storage-mirror.s3.fr-par.scw.cloud/' + ARCHIVE
def _download_and_preprocess_data(target_dir, english_compatible=False): def _download_and_preprocess_data(target_dir, english_compatible=False):
# Making path absolute # Making path absolute
target_dir = path.abspath(target_dir) target_dir = os.path.abspath(target_dir)
# Conditionally download data # Conditionally download data
archive_path = maybe_download('ts_' + ARCHIVE_NAME + '.zip', target_dir, ARCHIVE_URL) archive_path = maybe_download('ts_' + ARCHIVE_NAME + '.zip', target_dir, ARCHIVE_URL)
# Conditionally extract archive data # Conditionally extract archive data
@@ -45,8 +37,8 @@ def _download_and_preprocess_data(target_dir, english_compatible=False):
def _maybe_extract(target_dir, extracted_data, archive_path): def _maybe_extract(target_dir, extracted_data, archive_path):
# If target_dir/extracted_data does not exist, extract archive in target_dir # If target_dir/extracted_data does not exist, extract archive in target_dir
extracted_path = path.join(target_dir, extracted_data) extracted_path = os.path.join(target_dir, extracted_data)
if not path.exists(extracted_path): if not os.path.exists(extracted_path):
print('No directory "%s" - extracting archive...' % extracted_path) print('No directory "%s" - extracting archive...' % extracted_path)
if not os.path.isdir(extracted_path): if not os.path.isdir(extracted_path):
os.mkdir(extracted_path) os.mkdir(extracted_path)
@@ -60,12 +52,12 @@ def one_sample(sample):
""" Take a audio file, and optionally convert it to 16kHz WAV """ """ Take a audio file, and optionally convert it to 16kHz WAV """
orig_filename = sample['path'] orig_filename = sample['path']
# Storing converted wav files next to the original ones - just with a different suffix # Storing converted wav files next to the original ones - just with a different suffix
wav_filename = path.splitext(orig_filename)[0] + ".converted.wav" wav_filename = os.path.splitext(orig_filename)[0] + ".converted.wav"
_maybe_convert_wav(orig_filename, wav_filename) _maybe_convert_wav(orig_filename, wav_filename)
file_size = -1 file_size = -1
frames = 0 frames = 0
if path.exists(wav_filename): if os.path.exists(wav_filename):
file_size = path.getsize(wav_filename) file_size = os.path.getsize(wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT)) frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
label = sample['text'] label = sample['text']
@@ -95,7 +87,7 @@ def one_sample(sample):
def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False): def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False):
extracted_dir = path.join(target_dir, extracted_data) extracted_dir = os.path.join(target_dir, extracted_data)
# override existing CSV with normalized one # override existing CSV with normalized one
target_csv_template = os.path.join(target_dir, 'ts_' + ARCHIVE_NAME + '_{}.csv') target_csv_template = os.path.join(target_dir, 'ts_' + ARCHIVE_NAME + '_{}.csv')
if os.path.isfile(target_csv_template): if os.path.isfile(target_csv_template):
@@ -160,7 +152,7 @@ def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False):
print_import_report(counter, SAMPLE_RATE, MAX_SECS) print_import_report(counter, SAMPLE_RATE, MAX_SECS)
def _maybe_convert_wav(orig_filename, wav_filename): def _maybe_convert_wav(orig_filename, wav_filename):
if not path.exists(wav_filename): if not os.path.exists(wav_filename):
transformer = sox.Transformer() transformer = sox.Transformer()
transformer.convert(samplerate=SAMPLE_RATE) transformer.convert(samplerate=SAMPLE_RATE)
try: try:


@@ -5,25 +5,19 @@ Use "python3 import_tuda.py -h" for help
''' '''
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
import csv
import wave
import tarfile
import argparse import argparse
import csv
import os
import progressbar import progressbar
import tarfile
import unicodedata import unicodedata
import wave
import xml.etree.cElementTree as ET import xml.etree.cElementTree as ET
from os import path
from collections import Counter from collections import Counter
from util.text import Alphabet from deepspeech_training.util.downloader import maybe_download, SIMPLE_BAR
from util.importers import validate_label_eng as validate_label from deepspeech_training.util.importers import validate_label_eng as validate_label
from util.downloader import maybe_download, SIMPLE_BAR from deepspeech_training.util.text import Alphabet
TUDA_VERSION = 'v2' TUDA_VERSION = 'v2'
TUDA_PACKAGE = 'german-speechdata-package-{}'.format(TUDA_VERSION) TUDA_PACKAGE = 'german-speechdata-package-{}'.format(TUDA_VERSION)
@@ -38,8 +32,8 @@ FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
def maybe_extract(archive): def maybe_extract(archive):
extracted = path.join(CLI_ARGS.base_dir, TUDA_PACKAGE) extracted = os.path.join(CLI_ARGS.base_dir, TUDA_PACKAGE)
if path.isdir(extracted): if os.path.isdir(extracted):
print('Found directory "{}" - not extracting.'.format(extracted)) print('Found directory "{}" - not extracting.'.format(extracted))
else: else:
print('Extracting "{}"...'.format(archive)) print('Extracting "{}"...'.format(archive))
@@ -92,7 +86,7 @@ def write_csvs(extracted):
sample_counter = 0 sample_counter = 0
reasons = Counter() reasons = Counter()
for sub_set in ['train', 'dev', 'test']: for sub_set in ['train', 'dev', 'test']:
set_path = path.join(extracted, sub_set) set_path = os.path.join(extracted, sub_set)
set_files = os.listdir(set_path) set_files = os.listdir(set_path)
recordings = {} recordings = {}
for file in set_files: for file in set_files:
@@ -104,15 +98,15 @@ def write_csvs(extracted):
if prefix in recordings: if prefix in recordings:
recordings[prefix].append(file) recordings[prefix].append(file)
recordings = recordings.items() recordings = recordings.items()
csv_path = path.join(CLI_ARGS.base_dir, 'tuda-{}-{}.csv'.format(TUDA_VERSION, sub_set)) csv_path = os.path.join(CLI_ARGS.base_dir, 'tuda-{}-{}.csv'.format(TUDA_VERSION, sub_set))
print('Writing "{}"...'.format(csv_path)) print('Writing "{}"...'.format(csv_path))
with open(csv_path, 'w') as csv_file: with open(csv_path, 'w') as csv_file:
writer = csv.DictWriter(csv_file, fieldnames=FIELDNAMES) writer = csv.DictWriter(csv_file, fieldnames=FIELDNAMES)
writer.writeheader() writer.writeheader()
set_dir = path.join(extracted, sub_set) set_dir = os.path.join(extracted, sub_set)
bar = progressbar.ProgressBar(max_value=len(recordings), widgets=SIMPLE_BAR) bar = progressbar.ProgressBar(max_value=len(recordings), widgets=SIMPLE_BAR)
for prefix, wav_names in bar(recordings): for prefix, wav_names in bar(recordings):
xml_path = path.join(set_dir, prefix + '.xml') xml_path = os.path.join(set_dir, prefix + '.xml')
meta = ET.parse(xml_path).getroot() meta = ET.parse(xml_path).getroot()
sentence = list(meta.iter('cleaned_sentence'))[0].text sentence = list(meta.iter('cleaned_sentence'))[0].text
sentence = check_and_prepare_sentence(sentence) sentence = check_and_prepare_sentence(sentence)
@@ -120,12 +114,12 @@ def write_csvs(extracted):
continue continue
for wav_name in wav_names: for wav_name in wav_names:
sample_counter += 1 sample_counter += 1
wav_path = path.join(set_path, wav_name) wav_path = os.path.join(set_path, wav_name)
keep, reason = check_wav_file(wav_path, sentence) keep, reason = check_wav_file(wav_path, sentence)
if keep: if keep:
writer.writerow({ writer.writerow({
'wav_filename': path.relpath(wav_path, CLI_ARGS.base_dir), 'wav_filename': os.path.relpath(wav_path, CLI_ARGS.base_dir),
'wav_filesize': path.getsize(wav_path), 'wav_filesize': os.path.getsize(wav_path),
'transcript': sentence.lower() 'transcript': sentence.lower()
}) })
else: else:


@@ -1,30 +1,21 @@
#!/usr/bin/env python #!/usr/bin/env python
# VCTK used in wavenet paper https://arxiv.org/pdf/1609.03499.pdf # VCTK used in wavenet paper https://arxiv.org/pdf/1609.03499.pdf
# Licenced under Open Data Commons Attribution License (ODC-By) v1.0. # Licenced under Open Data Commons Attribution License (ODC-By) v1.0.
# as per https://homepages.inf.ed.ac.uk/jyamagis/page3/page58/page58.html # as per https://homepages.inf.ed.ac.uk/jyamagis/page3/page58/page58.html
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import random
import sys
sys.path.insert(1, os.path.join(sys.path[0], ".."))
from util.importers import get_counter, get_imported_samples, print_import_report
import re
import librosa import librosa
import os
import progressbar import progressbar
import random
import re
from os import path from deepspeech_training.util.downloader import maybe_download, SIMPLE_BAR
from deepspeech_training.util.importers import get_counter, get_imported_samples, print_import_report
from multiprocessing import Pool from multiprocessing import Pool
from util.downloader import maybe_download, SIMPLE_BAR
from zipfile import ZipFile from zipfile import ZipFile
SAMPLE_RATE = 16000 SAMPLE_RATE = 16000
MAX_SECS = 10 MAX_SECS = 10
MIN_SECS = 1 MIN_SECS = 1
@@ -37,7 +28,7 @@ ARCHIVE_URL = (
def _download_and_preprocess_data(target_dir): def _download_and_preprocess_data(target_dir):
# Making path absolute # Making path absolute
target_dir = path.abspath(target_dir) target_dir = os.path.abspath(target_dir)
# Conditionally download data # Conditionally download data
archive_path = maybe_download(ARCHIVE_NAME, target_dir, ARCHIVE_URL) archive_path = maybe_download(ARCHIVE_NAME, target_dir, ARCHIVE_URL)
# Conditionally extract common voice data # Conditionally extract common voice data
@@ -48,8 +39,8 @@ def _download_and_preprocess_data(target_dir):
def _maybe_extract(target_dir, extracted_data, archive_path): def _maybe_extract(target_dir, extracted_data, archive_path):
# If target_dir/extracted_data does not exist, extract archive in target_dir # If target_dir/extracted_data does not exist, extract archive in target_dir
extracted_path = path.join(target_dir, extracted_data) extracted_path = os.path.join(target_dir, extracted_data)
if not path.exists(extracted_path): if not os.path.exists(extracted_path):
print(f"No directory {extracted_path} - extracting archive...") print(f"No directory {extracted_path} - extracting archive...")
with ZipFile(archive_path, "r") as zipobj: with ZipFile(archive_path, "r") as zipobj:
# Extract all the contents of zip file in current directory # Extract all the contents of zip file in current directory
@@ -59,8 +50,8 @@ def _maybe_extract(target_dir, extracted_data, archive_path):
def _maybe_convert_sets(target_dir, extracted_data): def _maybe_convert_sets(target_dir, extracted_data):
extracted_dir = path.join(target_dir, extracted_data, "wav48") extracted_dir = os.path.join(target_dir, extracted_data, "wav48")
txt_dir = path.join(target_dir, extracted_data, "txt") txt_dir = os.path.join(target_dir, extracted_data, "txt")
directory = os.path.expanduser(extracted_dir) directory = os.path.expanduser(extracted_dir)
srtd = len(sorted(os.listdir(directory))) srtd = len(sorted(os.listdir(directory)))


@@ -2,23 +2,20 @@
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import codecs import codecs
import sys
import os import os
sys.path.insert(1, os.path.join(sys.path[0], '..'))
import tarfile
import pandas import pandas
import re import re
import unicodedata import tarfile
import threading import threading
from multiprocessing.pool import ThreadPool import unicodedata
from six.moves import urllib
from glob import glob
from os import makedirs, path
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
from deepspeech_training.util.downloader import maybe_download
from glob import glob
from multiprocessing.pool import ThreadPool
from os import makedirs, path
from six.moves import urllib
from tensorflow.python.platform import gfile from tensorflow.python.platform import gfile
from util.downloader import maybe_download
"""The number of jobs to run in parallel""" """The number of jobs to run in parallel"""
NUM_PARALLEL = 8 NUM_PARALLEL = 8
@@ -99,7 +96,7 @@ def _parallel_extracter(data_dir, number_of_test, number_of_dev, total, counter)
dataset_dir = path.join(data_dir, "dev") dataset_dir = path.join(data_dir, "dev")
else: else:
dataset_dir = path.join(data_dir, "train") dataset_dir = path.join(data_dir, "train")
if not gfile.Exists(path.join(dataset_dir, '.'.join(filename_of(archive).split(".")[:-1]))): if not gfile.Exists(os.path.join(dataset_dir, '.'.join(filename_of(archive).split(".")[:-1]))):
c = counter.increment() c = counter.increment()
print('Extracting file {} ({}/{})...'.format(i+1, c, total)) print('Extracting file {} ({}/{})...'.format(i+1, c, total))
tar = tarfile.open(archive) tar = tarfile.open(archive)
@@ -132,14 +129,14 @@ def _download_and_preprocess_data(data_dir):
p.map(downloader, enumerate(refs)) p.map(downloader, enumerate(refs))
# Conditionally extract data to dataset_dir # Conditionally extract data to dataset_dir
if not path.isdir(path.join(data_dir,"test")): if not path.isdir(os.path.join(data_dir, "test")):
makedirs(path.join(data_dir,"test")) makedirs(os.path.join(data_dir, "test"))
if not path.isdir(path.join(data_dir,"dev")): if not path.isdir(os.path.join(data_dir, "dev")):
makedirs(path.join(data_dir,"dev")) makedirs(os.path.join(data_dir, "dev"))
if not path.isdir(path.join(data_dir,"train")): if not path.isdir(os.path.join(data_dir, "train")):
makedirs(path.join(data_dir,"train")) makedirs(os.path.join(data_dir, "train"))
tarfiles = glob(path.join(archive_dir, "*.tgz")) tarfiles = glob(os.path.join(archive_dir, "*.tgz"))
number_of_files = len(tarfiles) number_of_files = len(tarfiles)
number_of_test = number_of_files//100 number_of_test = number_of_files//100
number_of_dev = number_of_files//100 number_of_dev = number_of_files//100
@@ -156,20 +153,20 @@ def _download_and_preprocess_data(data_dir):
train_files = _generate_dataset(data_dir, "train") train_files = _generate_dataset(data_dir, "train")
# Write sets to disk as CSV files # Write sets to disk as CSV files
train_files.to_csv(path.join(data_dir, "voxforge-train.csv"), index=False) train_files.to_csv(os.path.join(data_dir, "voxforge-train.csv"), index=False)
dev_files.to_csv(path.join(data_dir, "voxforge-dev.csv"), index=False) dev_files.to_csv(os.path.join(data_dir, "voxforge-dev.csv"), index=False)
test_files.to_csv(path.join(data_dir, "voxforge-test.csv"), index=False) test_files.to_csv(os.path.join(data_dir, "voxforge-test.csv"), index=False)
def _generate_dataset(data_dir, data_set): def _generate_dataset(data_dir, data_set):
extracted_dir = path.join(data_dir, data_set) extracted_dir = path.join(data_dir, data_set)
files = [] files = []
for promts_file in glob(path.join(extracted_dir+"/*/etc/", "PROMPTS")): for promts_file in glob(os.path.join(extracted_dir+"/*/etc/", "PROMPTS")):
if path.isdir(path.join(promts_file[:-11],"wav")): if path.isdir(os.path.join(promts_file[:-11], "wav")):
with codecs.open(promts_file, 'r', 'utf-8') as f: with codecs.open(promts_file, 'r', 'utf-8') as f:
for line in f: for line in f:
id = line.split(' ')[0].split('/')[-1] id = line.split(' ')[0].split('/')[-1]
sentence = ' '.join(line.split(' ')[1:]) sentence = ' '.join(line.split(' ')[1:])
sentence = re.sub("[^a-z']"," ",sentence.strip().lower()) sentence = re.sub("[^a-z']", " ",sentence.strip().lower())
transcript = "" transcript = ""
for token in sentence.split(" "): for token in sentence.split(" "):
word = token.strip() word = token.strip()
@@ -178,14 +175,14 @@ def _generate_dataset(data_dir, data_set):
transcript = unicodedata.normalize("NFKD", transcript.strip()) \ transcript = unicodedata.normalize("NFKD", transcript.strip()) \
.encode("ascii", "ignore") \ .encode("ascii", "ignore") \
.decode("ascii", "ignore") .decode("ascii", "ignore")
wav_file = path.join(promts_file[:-11],"wav/" + id + ".wav") wav_file = path.join(promts_file[:-11], "wav/" + id + ".wav")
if gfile.Exists(wav_file): if gfile.Exists(wav_file):
wav_filesize = path.getsize(wav_file) wav_filesize = path.getsize(wav_file)
# remove audios that are shorter than 0.5s and longer than 20s. # remove audios that are shorter than 0.5s and longer than 20s.
# remove audios that are too short for transcript. # remove audios that are too short for transcript.
if (wav_filesize/32000)>0.5 and (wav_filesize/32000)<20 and transcript!="" and \ if ((wav_filesize/32000) > 0.5 and (wav_filesize/32000) < 20 and transcript != "" and
wav_filesize/len(transcript)>1400: wav_filesize/len(transcript) > 1400):
files.append((path.abspath(wav_file), wav_filesize, transcript)) files.append((os.path.abspath(wav_file), wav_filesize, transcript))
return pandas.DataFrame(data=files, columns=["wav_filename", "wav_filesize", "transcript"]) return pandas.DataFrame(data=files, columns=["wav_filename", "wav_filesize", "transcript"])
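A quick reading of the size filter above, given the 16 kHz, 16-bit mono audio this importer works with:

    # 32000 bytes = 16000 samples/s * 2 bytes/sample = 1 second of audio, so the
    # filter keeps clips between 0.5 s and 20 s, and wav_filesize/len(transcript)
    # > 1400 demands roughly 44 ms of audio per transcript character.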


@@ -5,17 +5,12 @@ Use "python3 build_sdb.py -h" for help
""" """
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
import random
import argparse import argparse
import random
import sys
from util.sample_collections import samples_from_file, LabeledSample from deepspeech_training.util.audio import AUDIO_TYPE_PCM
from util.audio import AUDIO_TYPE_PCM from deepspeech_training.util.sample_collections import samples_from_file, LabeledSample
def play_sample(samples, index): def play_sample(samples, index):


@@ -1,17 +1,11 @@
#!/usr/bin/env python #!/usr/bin/env python
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], "..", ".."))
import argparse import argparse
import shutil import shutil
import sys
from util.text import Alphabet, UTF8Alphabet from deepspeech_training.util.text import Alphabet, UTF8Alphabet
from ds_ctcdecoder import Scorer, Alphabet as NativeAlphabet from ds_ctcdecoder import Scorer, Alphabet as NativeAlphabet


@@ -25,7 +25,7 @@ In creating a virtual environment you will create a directory containing a ``pyt
.. code-block:: .. code-block::
$ virtualenv -p python3 $HOME/tmp/deepspeech-train-venv/ $ python3 -m venv $HOME/tmp/deepspeech-train-venv/
Once this command completes successfully, the environment will be ready to be activated. Once this command completes successfully, the environment will be ready to be activated.
@@ -46,7 +46,7 @@ Install the required dependencies using ``pip3``\ :
.. code-block:: bash .. code-block:: bash
cd DeepSpeech cd DeepSpeech
pip3 install -r requirements.txt pip3 install -e .
The ``webrtcvad`` Python package might require you to ensure you have proper tooling to build Python modules: The ``webrtcvad`` Python package might require you to ensure you have proper tooling to build Python modules:
@@ -70,7 +70,7 @@ If you have a capable (NVIDIA, at least 8GB of VRAM) GPU, it is highly recommend
.. code-block:: bash .. code-block:: bash
pip3 uninstall tensorflow pip3 uninstall tensorflow
pip3 install 'tensorflow-gpu==1.15.0' pip3 install 'tensorflow-gpu==1.15.2'
Please ensure you have the required `CUDA dependency <USING.rst#cuda-dependency>`_. Please ensure you have the required `CUDA dependency <USING.rst#cuda-dependency>`_.
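Once the project is installed with pip3 install -e ., the training code resolves as a regular package from any working directory, which is what lets the scripts in this commit drop their sys.path edits. For instance, imports like these (all taken from the diffs above) now work without any path manipulation:

    from deepspeech_training.util.downloader import maybe_download
    from deepspeech_training.util.importers import get_importers_parser
    from deepspeech_training.util.text import Alphabet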

evaluate.py Executable file → Normal file

@@ -2,155 +2,11 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import json
import sys
from multiprocessing import cpu_count
import absl.app
import progressbar
import tensorflow as tf
import tensorflow.compat.v1 as tfv1
from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
from six.moves import zip
from util.config import Config, initialize_globals
from util.checkpoints import load_or_init_graph
from util.evaluate_tools import calculate_and_print_report
from util.feeding import create_dataset
from util.flags import create_flags, FLAGS
from util.helpers import check_ctcdecoder_version
from util.logging import create_progressbar, log_error, log_progress
check_ctcdecoder_version()
def sparse_tensor_value_to_texts(value, alphabet):
r"""
Given a :class:`tf.SparseTensor` ``value``, return an array of Python strings
representing its values, converting tokens to strings using ``alphabet``.
"""
return sparse_tuple_to_texts((value.indices, value.values, value.dense_shape), alphabet)
def sparse_tuple_to_texts(sp_tuple, alphabet):
indices = sp_tuple[0]
values = sp_tuple[1]
results = [[] for _ in range(sp_tuple[2][0])]
for i, index in enumerate(indices):
results[index[0]].append(values[i])
# List of strings
return [alphabet.decode(res) for res in results]
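A tiny worked example of the regrouping, with a hypothetical alphabet:

    # indices pair (row, position), values are token ids, dense_shape is [batch, max_len]
    sp_tuple = ([[0, 0], [0, 1], [1, 0]],  # indices
                [7, 4, 2],                 # values
                [2, 2])                    # dense_shape
    # Grouping by row yields [[7, 4], [2]]; alphabet.decode() then turns each
    # id list into a string, e.g. ['he', 'c'] for a suitable alphabet.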
def evaluate(test_csvs, create_model):
if FLAGS.scorer_path:
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
FLAGS.scorer_path, Config.alphabet)
else:
scorer = None
test_csvs = FLAGS.test_files.split(',')
test_sets = [create_dataset([csv], batch_size=FLAGS.test_batch_size, train_phase=False) for csv in test_csvs]
iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(test_sets[0]),
tfv1.data.get_output_shapes(test_sets[0]),
output_classes=tfv1.data.get_output_classes(test_sets[0]))
test_init_ops = [iterator.make_initializer(test_set) for test_set in test_sets]
batch_wav_filename, (batch_x, batch_x_len), batch_y = iterator.get_next()
# One rate per layer
no_dropout = [None] * 6
logits, _ = create_model(batch_x=batch_x,
batch_size=FLAGS.test_batch_size,
seq_length=batch_x_len,
dropout=no_dropout)
# Transpose to batch major and apply softmax for decoder
transposed = tf.nn.softmax(tf.transpose(a=logits, perm=[1, 0, 2]))
loss = tfv1.nn.ctc_loss(labels=batch_y,
inputs=logits,
sequence_length=batch_x_len)
tfv1.train.get_or_create_global_step()
# Get number of accessible CPU cores for this process
try:
num_processes = cpu_count()
except NotImplementedError:
num_processes = 1
with tfv1.Session(config=Config.session_config) as session:
if FLAGS.load == 'auto':
method_order = ['best', 'last']
else:
method_order = [FLAGS.load]
load_or_init_graph(session, method_order)
def run_test(init_op, dataset):
wav_filenames = []
losses = []
predictions = []
ground_truths = []
bar = create_progressbar(prefix='Test epoch | ',
widgets=['Steps: ', progressbar.Counter(), ' | ', progressbar.Timer()]).start()
log_progress('Test epoch...')
step_count = 0
# Initialize iterator to the appropriate dataset
session.run(init_op)
# First pass, compute losses and transposed logits for decoding
while True:
try:
batch_wav_filenames, batch_logits, batch_loss, batch_lengths, batch_transcripts = \
session.run([batch_wav_filename, transposed, loss, batch_x_len, batch_y])
except tf.errors.OutOfRangeError:
break
decoded = ctc_beam_search_decoder_batch(batch_logits, batch_lengths, Config.alphabet, FLAGS.beam_width,
num_processes=num_processes, scorer=scorer,
cutoff_prob=FLAGS.cutoff_prob, cutoff_top_n=FLAGS.cutoff_top_n)
predictions.extend(d[0][1] for d in decoded)
ground_truths.extend(sparse_tensor_value_to_texts(batch_transcripts, Config.alphabet))
wav_filenames.extend(wav_filename.decode('UTF-8') for wav_filename in batch_wav_filenames)
losses.extend(batch_loss)
step_count += 1
bar.update(step_count)
bar.finish()
# Print test summary
test_samples = calculate_and_print_report(wav_filenames, ground_truths, predictions, losses, dataset)
return test_samples
samples = []
for csv, init_op in zip(test_csvs, test_init_ops):
print('Testing model on {}'.format(csv))
samples.extend(run_test(init_op, dataset=csv))
return samples
def main(_):
initialize_globals()
if not FLAGS.test_files:
log_error('You need to specify what files to use for evaluation via '
'the --test_files flag.')
sys.exit(1)
from DeepSpeech import create_model # pylint: disable=cyclic-import,import-outside-toplevel
samples = evaluate(FLAGS.test_files.split(','), create_model)
if FLAGS.test_output_file:
# Save decoded tuples as JSON, converting NumPy floats to Python floats
json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float)
if __name__ == '__main__': if __name__ == '__main__':
create_flags() try:
absl.app.run(main) from deepspeech_training import evaluate as ds_evaluate
except ImportError:
print('Training package is not installed. See training documentation.')
raise
ds_evaluate.run_script()
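evaluate.py is now a thin shim over the packaged module. Presumably deepspeech_training.evaluate.run_script wraps the flag registration and absl entry point deleted from this file; a sketch under that assumption:

    # Hypothetical shape of run_script(), mirroring the removed entry point:
    def run_script():
        create_flags()        # register the command line flags
        absl.app.run(main)    # parse argv and hand control to main(_)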


@@ -10,13 +10,12 @@ import csv
import os import os
import sys import sys
from functools import partial
from six.moves import zip, range
from multiprocessing import JoinableQueue, Process, cpu_count, Manager
from deepspeech import Model from deepspeech import Model
from deepspeech_training.util.evaluate_tools import calculate_and_print_report
from util.evaluate_tools import calculate_and_print_report from deepspeech_training.util.flags import create_flags
from util.flags import create_flags from functools import partial
from multiprocessing import JoinableQueue, Process, cpu_count, Manager
from six.moves import zip, range
r''' r'''
This module should be self-contained: This module should be self-contained:


@@ -2,19 +2,18 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from __future__ import absolute_import, print_function from __future__ import absolute_import, print_function
import sys
import optuna
import absl.app import absl.app
from ds_ctcdecoder import Scorer import optuna
import sys
import tensorflow.compat.v1 as tfv1 import tensorflow.compat.v1 as tfv1
from DeepSpeech import create_model from deepspeech_training.evaluate import evaluate
from evaluate import evaluate from deepspeech_training.train import create_model
from util.config import Config, initialize_globals from deepspeech_training.util.config import Config, initialize_globals
from util.flags import create_flags, FLAGS from deepspeech_training.util.flags import create_flags, FLAGS
from util.logging import log_error from deepspeech_training.util.logging import log_error
from util.evaluate_tools import wer_cer_batch from deepspeech_training.util.evaluate_tools import wer_cer_batch
from ds_ctcdecoder import Scorer
def character_based(): def character_based():


@@ -1,24 +0,0 @@
# Main training requirements
tensorflow == 1.15.2
numpy == 1.18.1
progressbar2
six
pyxdg
attrdict
absl-py
semver
opuslib == 2.0.0
# Requirements for building native_client files
setuptools
# Requirements for importers
sox
bs4
pandas
requests
librosa
soundfile
# Requirements for optimizer
optuna

setup.py Normal file

@@ -0,0 +1,53 @@
from setuptools import setup, find_packages
def main():
setup(
name='deepspeech_training',
version='0.0.1',
description='Training code for Mozilla DeepSpeech',
url='https://github.com/mozilla/DeepSpeech',
author='Mozilla',
license='MPL-2.0',
# Classifiers help users find your project by categorizing it.
#
# For a list of valid classifiers, see https://pypi.org/classifiers/
classifiers=[
'Development Status :: 3 - Alpha',
'Intended Audience :: Developers',
'Topic :: Multimedia :: Sound/Audio :: Speech',
'License :: OSI Approved :: Mozilla Public License 2.0 (MPL 2.0)',
'Programming Language :: Python :: 3',
],
package_dir={'': 'training'},
packages=find_packages(where='training'),
python_requires='>=3.5, <4',
install_requires=[
'tensorflow == 1.15.2',
'numpy == 1.18.1',
'progressbar2',
'six',
'pyxdg',
'attrdict',
'absl-py',
'semver',
'opuslib == 2.0.0',
'optuna',
'sox',
'bs4',
'pandas',
'requests',
'librosa',
'soundfile',
],
# If there are data files included in your packages that need to be
# installed, specify them here.
package_data={
'deepspeech_training': [
'VERSION',
'GRAPH_VERSION',
],
},
)
if __name__ == '__main__':
main()
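package_dir={'': 'training'} together with find_packages(where='training') implies the package sources live one level down. Assumed layout, consistent with the imports and package_data in this commit:

    # DeepSpeech/
    #   setup.py
    #   training/
    #     deepspeech_training/
    #       __init__.py
    #       VERSION
    #       GRAPH_VERSION
    #       evaluate.py
    #       train.py
    #       util/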


@@ -1,10 +1,29 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import argparse import argparse
import os import functools
import pandas
from deepspeech_training.util.helpers import secs_to_hours
from pathlib import Path
def read_csvs(csv_files):
# Relative paths are relative to CSV location
def absolutify(csv, path):
path = Path(path)
if path.is_absolute():
return str(path)
return str(csv.parent / path)
sets = []
for csv in csv_files:
file = pandas.read_csv(csv, encoding='utf-8', na_filter=False)
file['wav_filename'] = file['wav_filename'].apply(functools.partial(absolutify, csv))
sets.append(file)
# Concat all sets, drop any extra columns, re-index the final result as 0..N
return pandas.concat(sets, join='inner', ignore_index=True)
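Because absolutify resolves relative wav_filename entries against each CSV's own directory, sets stored in different places can be concatenated safely. Example usage, with hypothetical paths:

    from pathlib import Path

    # Each CSV may reference clips relative to its own directory:
    df = read_csvs([Path('/data/cv/train.csv'), Path('/data/ll/train.csv')])
    print(df['wav_filename'].head())  # absolute paths after absolutify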
from util.helpers import secs_to_hours
from util.feeding import read_csvs
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
@@ -14,20 +33,16 @@ def main():
parser.add_argument("--channels", type=int, default=1, required=False, help="Audio channels") parser.add_argument("--channels", type=int, default=1, required=False, help="Audio channels")
parser.add_argument("--bits-per-sample", type=int, default=16, required=False, help="Audio bits per sample") parser.add_argument("--bits-per-sample", type=int, default=16, required=False, help="Audio bits per sample")
args = parser.parse_args() args = parser.parse_args()
in_files = [os.path.abspath(i) for i in args.csv_files.split(",")] in_files = [Path(i).absolute() for i in args.csv_files.split(",")]
csv_dataframe = read_csvs(in_files) csv_dataframe = read_csvs(in_files)
total_bytes = csv_dataframe['wav_filesize'].sum() total_bytes = csv_dataframe['wav_filesize'].sum()
total_files = len(csv_dataframe.index) total_files = len(csv_dataframe)
total_seconds = ((csv_dataframe['wav_filesize'] - 44) / args.sample_rate / args.channels / (args.bits_per_sample // 8)).sum()
bytes_without_headers = total_bytes - 44 * total_files print('Total bytes:', total_bytes)
print('Total files:', total_files)
total_time = bytes_without_headers / (args.sample_rate * args.channels * args.bits_per_sample / 8) print('Total time:', secs_to_hours(total_seconds))
print('total_bytes', total_bytes)
print('total_files', total_files)
print('bytes_without_headers', bytes_without_headers)
print('total_time', secs_to_hours(total_time))
if __name__ == '__main__': if __name__ == '__main__':
main() main()
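The rewritten duration computation subtracts the 44-byte canonical WAV header from each file size before dividing by the byte rate. A worked example with made-up numbers (16 kHz, mono, 16 bits per sample):

wav_filesize = 320044                  # hypothetical file size in bytes
sample_rate, channels, bits_per_sample = 16000, 1, 16
seconds = (wav_filesize - 44) / sample_rate / channels / (bits_per_sample // 8)
assert seconds == 10.0                 # 320000 payload bytes / 32000 bytes per second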


@@ -17,7 +17,9 @@ deepspeech_pkg_url=$(get_python_pkg_url ${pyver_pkg} ${py_unicode_type})
 set -o pipefail
 LD_LIBRARY_PATH=${PY37_LDPATH}:$LD_LIBRARY_PATH pip install --verbose --only-binary :all: --upgrade ${deepspeech_pkg_url} | cat
 pip install --upgrade pip==19.3.1 setuptools==45.0.0 wheel==0.33.6 | cat
-pip install --upgrade -r ${HOME}/DeepSpeech/ds/requirements.txt | cat
+pushd ${HOME}/DeepSpeech/ds
+    pip install --upgrade . | cat
+popd
 set +o pipefail

 which deepspeech


@@ -17,7 +17,9 @@ virtualenv_activate "${pyalias}" "deepspeech"
 set -o pipefail
 pip install --upgrade pip==19.3.1 setuptools==45.0.0 wheel==0.33.6 | cat
-pip install --upgrade -r ${HOME}/DeepSpeech/ds/requirements.txt | cat
+pushd ${HOME}/DeepSpeech/ds
+    pip install --upgrade . | cat
+popd
 set +o pipefail

 decoder_pkg_url=$(get_python_pkg_url ${pyver_pkg} ${py_unicode_type} "ds_ctcdecoder" "${DECODER_ARTIFACTS_ROOT}")


@@ -16,7 +16,9 @@ virtualenv_activate "${pyalias}" "deepspeech"
 set -o pipefail
 pip install --upgrade pip==19.3.1 setuptools==45.0.0 wheel==0.33.6 | cat
-pip install --upgrade -r ${HOME}/DeepSpeech/ds/requirements.txt | cat
+pushd ${HOME}/DeepSpeech/ds
+    pip install --upgrade . | cat
+popd
 set +o pipefail

 pushd ${HOME}/DeepSpeech/ds/


@@ -14,7 +14,9 @@ virtualenv_activate "${pyalias}" "deepspeech"
 set -o pipefail
 pip install --upgrade pip==19.3.1 setuptools==45.0.0 wheel==0.33.6 | cat
-pip install --upgrade -r ${HOME}/DeepSpeech/ds/requirements.txt | cat
+pushd ${HOME}/DeepSpeech/ds
+    pip install --upgrade . | cat
+popd
 set +o pipefail

 pushd ${HOME}/DeepSpeech/ds/


@@ -1,10 +1,14 @@
 import unittest
 from argparse import Namespace

-from .importers import validate_label_eng, get_validate_label
+from deepspeech_training.util.importers import validate_label_eng, get_validate_label
+from pathlib import Path
+
+
+def from_here(path):
+    here = Path(__file__)
+    return here.parent / path

 class TestValidateLabelEng(unittest.TestCase):
     def test_numbers(self):
         label = validate_label_eng("this is a 1 2 3 test")
         self.assertEqual(label, None)
@@ -24,12 +28,12 @@ class TestGetValidateLabel(unittest.TestCase):
         self.assertEqual(f('toto1234[{[{[]'), None)

     def test_get_validate_label_missing(self):
-        args = Namespace(validate_label_locale='util/test_data/validate_locale_ger.py')
+        args = Namespace(validate_label_locale=from_here('test_data/validate_locale_ger.py'))
         f = get_validate_label(args)
         self.assertEqual(f, None)

     def test_get_validate_label(self):
-        args = Namespace(validate_label_locale='util/test_data/validate_locale_fra.py')
+        args = Namespace(validate_label_locale=from_here('test_data/validate_locale_fra.py'))
         f = get_validate_label(args)
         l = f('toto')
         self.assertEqual(l, 'toto')


@@ -1,7 +1,7 @@
 import unittest
 import os

-from .text import Alphabet
+from deepspeech_training.util.text import Alphabet

 class TestAlphabetParsing(unittest.TestCase):


@@ -0,0 +1 @@
../../GRAPH_VERSION


@@ -0,0 +1 @@
../../VERSION


@@ -0,0 +1,159 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function
import json
import sys
from multiprocessing import cpu_count
import absl.app
import progressbar
import tensorflow as tf
import tensorflow.compat.v1 as tfv1
from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
from six.moves import zip
from .util.config import Config, initialize_globals
from .util.checkpoints import load_or_init_graph
from .util.evaluate_tools import calculate_and_print_report
from .util.feeding import create_dataset
from .util.flags import create_flags, FLAGS
from .util.helpers import check_ctcdecoder_version
from .util.logging import create_progressbar, log_error, log_progress
check_ctcdecoder_version()
def sparse_tensor_value_to_texts(value, alphabet):
r"""
Given a :class:`tf.SparseTensor` ``value``, return an array of Python strings
representing its values, converting tokens to strings using ``alphabet``.
"""
return sparse_tuple_to_texts((value.indices, value.values, value.dense_shape), alphabet)
def sparse_tuple_to_texts(sp_tuple, alphabet):
indices = sp_tuple[0]
values = sp_tuple[1]
results = [[] for _ in range(sp_tuple[2][0])]
for i, index in enumerate(indices):
results[index[0]].append(values[i])
# List of strings
return [alphabet.decode(res) for res in results]
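# Illustrative sketch (not part of evaluation): how a sparse (indices, values,
# dense_shape) tuple for a batch of two transcripts is grouped per sample. The
# token ids and shapes below are made up; alphabet.decode then maps each id
# list to a string.
def _sparse_tuple_sketch():
    indices = [(0, 0), (0, 1), (1, 0)]     # (batch, time) coordinates
    values = [7, 4, 8]                     # token ids
    dense_shape = [2, 2]                   # [batch_size, max_transcript_len]
    results = [[] for _ in range(dense_shape[0])]
    for i, index in enumerate(indices):
        results[index[0]].append(values[i])
    assert results == [[7, 4], [8]]
    return results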
def evaluate(test_csvs, create_model):
if FLAGS.scorer_path:
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
FLAGS.scorer_path, Config.alphabet)
else:
scorer = None
test_csvs = FLAGS.test_files.split(',')
test_sets = [create_dataset([csv], batch_size=FLAGS.test_batch_size, train_phase=False) for csv in test_csvs]
iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(test_sets[0]),
tfv1.data.get_output_shapes(test_sets[0]),
output_classes=tfv1.data.get_output_classes(test_sets[0]))
test_init_ops = [iterator.make_initializer(test_set) for test_set in test_sets]
batch_wav_filename, (batch_x, batch_x_len), batch_y = iterator.get_next()
# One rate per layer
no_dropout = [None] * 6
logits, _ = create_model(batch_x=batch_x,
batch_size=FLAGS.test_batch_size,
seq_length=batch_x_len,
dropout=no_dropout)
# Transpose to batch major and apply softmax for decoder
transposed = tf.nn.softmax(tf.transpose(a=logits, perm=[1, 0, 2]))
loss = tfv1.nn.ctc_loss(labels=batch_y,
inputs=logits,
sequence_length=batch_x_len)
tfv1.train.get_or_create_global_step()
# Get number of accessible CPU cores for this process
try:
num_processes = cpu_count()
except NotImplementedError:
num_processes = 1
with tfv1.Session(config=Config.session_config) as session:
if FLAGS.load == 'auto':
method_order = ['best', 'last']
else:
method_order = [FLAGS.load]
load_or_init_graph(session, method_order)
def run_test(init_op, dataset):
wav_filenames = []
losses = []
predictions = []
ground_truths = []
bar = create_progressbar(prefix='Test epoch | ',
widgets=['Steps: ', progressbar.Counter(), ' | ', progressbar.Timer()]).start()
log_progress('Test epoch...')
step_count = 0
# Initialize iterator to the appropriate dataset
session.run(init_op)
# First pass, compute losses and transposed logits for decoding
while True:
try:
batch_wav_filenames, batch_logits, batch_loss, batch_lengths, batch_transcripts = \
session.run([batch_wav_filename, transposed, loss, batch_x_len, batch_y])
except tf.errors.OutOfRangeError:
break
decoded = ctc_beam_search_decoder_batch(batch_logits, batch_lengths, Config.alphabet, FLAGS.beam_width,
num_processes=num_processes, scorer=scorer,
cutoff_prob=FLAGS.cutoff_prob, cutoff_top_n=FLAGS.cutoff_top_n)
predictions.extend(d[0][1] for d in decoded)
ground_truths.extend(sparse_tensor_value_to_texts(batch_transcripts, Config.alphabet))
wav_filenames.extend(wav_filename.decode('UTF-8') for wav_filename in batch_wav_filenames)
losses.extend(batch_loss)
step_count += 1
bar.update(step_count)
bar.finish()
# Print test summary
test_samples = calculate_and_print_report(wav_filenames, ground_truths, predictions, losses, dataset)
return test_samples
samples = []
for csv, init_op in zip(test_csvs, test_init_ops):
print('Testing model on {}'.format(csv))
samples.extend(run_test(init_op, dataset=csv))
return samples
def main(_):
initialize_globals()
if not FLAGS.test_files:
log_error('You need to specify what files to use for evaluation via '
'the --test_files flag.')
sys.exit(1)
from .train import create_model # pylint: disable=cyclic-import,import-outside-toplevel
samples = evaluate(FLAGS.test_files.split(','), create_model)
if FLAGS.test_output_file:
# Save decoded tuples as JSON, converting NumPy floats to Python floats
json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float)
def run_script():
create_flags()
absl.app.run(main)
if __name__ == '__main__':
run_script()


@@ -0,0 +1,936 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function
import os
import sys
LOG_LEVEL_INDEX = sys.argv.index('--log_level') + 1 if '--log_level' in sys.argv else 0
DESIRED_LOG_LEVEL = sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.argv) else '3'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = DESIRED_LOG_LEVEL
import absl.app
import json
import numpy as np
import progressbar
import shutil
import tensorflow as tf
import tensorflow.compat.v1 as tfv1
import time
tfv1.logging.set_verbosity({
'0': tfv1.logging.DEBUG,
'1': tfv1.logging.INFO,
'2': tfv1.logging.WARN,
'3': tfv1.logging.ERROR
}.get(DESIRED_LOG_LEVEL))
from datetime import datetime
from ds_ctcdecoder import ctc_beam_search_decoder, Scorer
from .evaluate import evaluate
from six.moves import zip, range
from .util.config import Config, initialize_globals
from .util.checkpoints import load_or_init_graph
from .util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features
from .util.flags import create_flags, FLAGS
from .util.helpers import check_ctcdecoder_version, ExceptionBox
from .util.logging import log_info, log_error, log_debug, log_progress, create_progressbar
check_ctcdecoder_version()
# Graph Creation
# ==============
def variable_on_cpu(name, shape, initializer):
r"""
Next we concern ourselves with graph creation.
However, before we do so we must introduce a utility function ``variable_on_cpu()``
used to create a variable in CPU memory.
"""
# Use the /cpu:0 device for scoped operations
with tf.device(Config.cpu_device):
# Create or get apropos variable
var = tfv1.get_variable(name=name, shape=shape, initializer=initializer)
return var
def create_overlapping_windows(batch_x):
batch_size = tf.shape(input=batch_x)[0]
window_width = 2 * Config.n_context + 1
num_channels = Config.n_input
# Create a constant convolution filter using an identity matrix, so that the
# convolution returns patches of the input tensor as is, and we can create
# overlapping windows over the MFCCs.
eye_filter = tf.constant(np.eye(window_width * num_channels)
.reshape(window_width, num_channels, window_width * num_channels), tf.float32) # pylint: disable=bad-continuation
# Create overlapping windows
batch_x = tf.nn.conv1d(input=batch_x, filters=eye_filter, stride=1, padding='SAME')
# Remove dummy depth dimension and reshape into [batch_size, n_windows, window_width, n_input]
batch_x = tf.reshape(batch_x, [batch_size, -1, window_width, num_channels])
return batch_x
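# Illustrative sketch (not part of the model): what the identity-filter
# convolution above computes, re-expressed with plain NumPy slicing on toy
# sizes. The n_input/n_context values here are made up; the real ones come
# from Config.
def _windowing_sketch():
    n_input, n_context = 2, 1                 # toy sizes
    window_width = 2 * n_context + 1          # 3 frames per window
    x = np.arange(10, dtype=np.float32).reshape(1, 5, n_input)    # [batch, time, n_input]
    padded = np.pad(x, ((0, 0), (n_context, n_context), (0, 0)))  # SAME padding
    # One flattened window of neighbouring frames per time step, exactly what
    # the eye_filter conv1d returns before the final tf.reshape
    windows = np.stack([padded[:, t:t + window_width, :].reshape(1, -1)
                        for t in range(x.shape[1])], axis=1)
    assert windows.shape == (1, 5, window_width * n_input)
    return windows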
def dense(name, x, units, dropout_rate=None, relu=True):
with tfv1.variable_scope(name):
bias = variable_on_cpu('bias', [units], tfv1.zeros_initializer())
weights = variable_on_cpu('weights', [x.shape[-1], units], tfv1.keras.initializers.VarianceScaling(scale=1.0, mode="fan_avg", distribution="uniform"))
output = tf.nn.bias_add(tf.matmul(x, weights), bias)
if relu:
output = tf.minimum(tf.nn.relu(output), FLAGS.relu_clip)
if dropout_rate is not None:
output = tf.nn.dropout(output, rate=dropout_rate)
return output
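# Illustrative sketch (not part of the model): the clipped-ReLU affine map
# computed by dense(), with dropout omitted and toy numbers. Assumes the
# default FLAGS.relu_clip of 20.
def _dense_sketch():
    relu_clip = 20.0
    x = np.array([[1.0, -2.0]])            # [batch, features]
    weights = np.array([[3.0], [0.5]])     # [features, units]
    bias = np.array([25.0])
    out = np.minimum(np.maximum(x @ weights + bias, 0.0), relu_clip)
    assert out[0, 0] == 20.0               # 27.0 clipped down to relu_clip
    return out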
def rnn_impl_lstmblockfusedcell(x, seq_length, previous_state, reuse):
with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell/cell_0'):
fw_cell = tf.contrib.rnn.LSTMBlockFusedCell(Config.n_cell_dim,
forget_bias=0,
reuse=reuse,
name='cudnn_compatible_lstm_cell')
output, output_state = fw_cell(inputs=x,
dtype=tf.float32,
sequence_length=seq_length,
initial_state=previous_state)
return output, output_state
def rnn_impl_cudnn_rnn(x, seq_length, previous_state, _):
assert previous_state is None # 'Passing previous state not supported with CuDNN backend'
# Hack: CudnnLSTM works similarly to Keras layers in that when you instantiate
# the object it creates the variables, and then you just call it several times
# to enable variable re-use. Because all of our code is structured in the old
# school TensorFlow style where you can just call tf.get_variable again with
# reuse=True to reuse variables, we can't easily make use of the object oriented
# way CudnnLSTM is implemented, so we save a singleton instance in the function,
# emulating a static function variable.
if not rnn_impl_cudnn_rnn.cell:
# Forward direction cell:
fw_cell = tf.contrib.cudnn_rnn.CudnnLSTM(num_layers=1,
num_units=Config.n_cell_dim,
input_mode='linear_input',
direction='unidirectional',
dtype=tf.float32)
rnn_impl_cudnn_rnn.cell = fw_cell
output, output_state = rnn_impl_cudnn_rnn.cell(inputs=x,
sequence_lengths=seq_length)
return output, output_state
rnn_impl_cudnn_rnn.cell = None
def rnn_impl_static_rnn(x, seq_length, previous_state, reuse):
with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell'):
# Forward direction cell:
fw_cell = tfv1.nn.rnn_cell.LSTMCell(Config.n_cell_dim,
forget_bias=0,
reuse=reuse,
name='cudnn_compatible_lstm_cell')
# Split rank N tensor into list of rank N-1 tensors
x = [x[l] for l in range(x.shape[0])]
output, output_state = tfv1.nn.static_rnn(cell=fw_cell,
inputs=x,
sequence_length=seq_length,
initial_state=previous_state,
dtype=tf.float32,
scope='cell_0')
output = tf.concat(output, 0)
return output, output_state
def create_model(batch_x, seq_length, dropout, reuse=False, batch_size=None, previous_state=None, overlap=True, rnn_impl=rnn_impl_lstmblockfusedcell):
layers = {}
# Input shape: [batch_size, n_steps, n_input + 2*n_input*n_context]
if not batch_size:
batch_size = tf.shape(input=batch_x)[0]
# Create overlapping feature windows if needed
if overlap:
batch_x = create_overlapping_windows(batch_x)
# Reshaping `batch_x` to a tensor with shape `[n_steps*batch_size, n_input + 2*n_input*n_context]`.
# This is done to prepare the batch for input into the first layer which expects a tensor of rank `2`.
# Permute n_steps and batch_size
batch_x = tf.transpose(a=batch_x, perm=[1, 0, 2, 3])
# Reshape to prepare input for first layer
batch_x = tf.reshape(batch_x, [-1, Config.n_input + 2*Config.n_input*Config.n_context]) # (n_steps*batch_size, n_input + 2*n_input*n_context)
layers['input_reshaped'] = batch_x
# The next three blocks will pass `batch_x` through three hidden layers with
# clipped RELU activation and dropout.
layers['layer_1'] = layer_1 = dense('layer_1', batch_x, Config.n_hidden_1, dropout_rate=dropout[0])
layers['layer_2'] = layer_2 = dense('layer_2', layer_1, Config.n_hidden_2, dropout_rate=dropout[1])
layers['layer_3'] = layer_3 = dense('layer_3', layer_2, Config.n_hidden_3, dropout_rate=dropout[2])
# `layer_3` is now reshaped into `[n_steps, batch_size, 2*n_cell_dim]`,
# as the LSTM RNN expects its input to be of shape `[max_time, batch_size, input_size]`.
layer_3 = tf.reshape(layer_3, [-1, batch_size, Config.n_hidden_3])
# Run through parametrized RNN implementation, as we use different RNNs
# for training and inference
output, output_state = rnn_impl(layer_3, seq_length, previous_state, reuse)
# Reshape output from a tensor of shape [n_steps, batch_size, n_cell_dim]
# to a tensor of shape [n_steps*batch_size, n_cell_dim]
output = tf.reshape(output, [-1, Config.n_cell_dim])
layers['rnn_output'] = output
layers['rnn_output_state'] = output_state
# Now we feed `output` to the fifth hidden layer with clipped RELU activation
layers['layer_5'] = layer_5 = dense('layer_5', output, Config.n_hidden_5, dropout_rate=dropout[5])
# Now we apply a final linear layer creating `n_classes` dimensional vectors, the logits.
layers['layer_6'] = layer_6 = dense('layer_6', layer_5, Config.n_hidden_6, relu=False)
# Finally we reshape layer_6 from a tensor of shape [n_steps*batch_size, n_hidden_6]
# to the slightly more useful shape [n_steps, batch_size, n_hidden_6].
# Note, that this differs from the input in that it is time-major.
layer_6 = tf.reshape(layer_6, [-1, batch_size, Config.n_hidden_6], name='raw_logits')
layers['raw_logits'] = layer_6
# Output shape: [n_steps, batch_size, n_hidden_6]
return layer_6, layers
# Accuracy and Loss
# =================
# In accord with 'Deep Speech: Scaling up end-to-end speech recognition'
# (http://arxiv.org/abs/1412.5567),
# the loss function used by our network should be the CTC loss function
# (http://www.cs.toronto.edu/~graves/preprint.pdf).
# Conveniently, this loss function is implemented in TensorFlow.
# Thus, we can simply make use of this implementation to define our loss.
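# In the cited papers' notation: for input x and transcript y, CTC minimises
#
#     L(x, y) = -log p(y|x) = -log sum over pi in B^-1(y) of prod_t p(pi_t | x)
#
# where B collapses repeated labels and removes blanks, so the sum runs over
# every frame-level alignment pi that collapses to y. (Summary for reference
# only; not part of this commit's code.)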
def calculate_mean_edit_distance_and_loss(iterator, dropout, reuse):
r'''
This routine beam search decodes a mini-batch and calculates the loss and mean edit distance.
Next to total and average loss it returns the mean edit distance,
the decoded result and the batch's original Y.
'''
# Obtain the next batch of data
batch_filenames, (batch_x, batch_seq_len), batch_y = iterator.get_next()
if FLAGS.train_cudnn:
rnn_impl = rnn_impl_cudnn_rnn
else:
rnn_impl = rnn_impl_lstmblockfusedcell
# Calculate the logits of the batch
logits, _ = create_model(batch_x, batch_seq_len, dropout, reuse=reuse, rnn_impl=rnn_impl)
# Compute the CTC loss using TensorFlow's `ctc_loss`
total_loss = tfv1.nn.ctc_loss(labels=batch_y, inputs=logits, sequence_length=batch_seq_len)
# Check if any files lead to non finite loss
non_finite_files = tf.gather(batch_filenames, tfv1.where(~tf.math.is_finite(total_loss)))
# Calculate the average loss across the batch
avg_loss = tf.reduce_mean(input_tensor=total_loss)
# Finally we return the average loss
return avg_loss, non_finite_files
# Adam Optimization
# =================
# In contrast to 'Deep Speech: Scaling up end-to-end speech recognition'
# (http://arxiv.org/abs/1412.5567),
# in which 'Nesterov's Accelerated Gradient Descent'
# (www.cs.toronto.edu/~fritz/absps/momentum.pdf) was used,
# we will use the Adam method for optimization (http://arxiv.org/abs/1412.6980),
# because, generally, it requires less fine-tuning.
def create_optimizer(learning_rate_var):
optimizer = tfv1.train.AdamOptimizer(learning_rate=learning_rate_var,
beta1=FLAGS.beta1,
beta2=FLAGS.beta2,
epsilon=FLAGS.epsilon)
return optimizer
# Towers
# ======
# In order to properly make use of multiple GPU's, one must introduce new abstractions,
# not present when using a single GPU, that facilitate the multi-GPU use case.
# In particular, one must introduce a means to isolate the inference and gradient
# calculations on the various GPU's.
# The abstraction we introduce for this purpose is called a 'tower'.
# A tower is specified by two properties:
# * **Scope** - A scope, as provided by `tf.name_scope()`,
# is a means to isolate the operations within a tower.
# For example, all operations within 'tower 0' could have their name prefixed with `tower_0/`.
# * **Device** - A hardware device, as provided by `tf.device()`,
# on which all operations within the tower execute.
# For example, all operations of 'tower 0' could execute on the first GPU `tf.device('/gpu:0')`.
def get_tower_results(iterator, optimizer, dropout_rates):
r'''
    With this preliminary step out of the way, we can, for each GPU, introduce a
    tower for whose batch we calculate and return the optimization gradients
and the average loss across towers.
'''
# To calculate the mean of the losses
tower_avg_losses = []
# Tower gradients to return
tower_gradients = []
# Aggregate any non finite files in the batches
tower_non_finite_files = []
with tfv1.variable_scope(tfv1.get_variable_scope()):
# Loop over available_devices
for i in range(len(Config.available_devices)):
# Execute operations of tower i on device i
device = Config.available_devices[i]
with tf.device(device):
# Create a scope for all operations of tower i
with tf.name_scope('tower_%d' % i):
# Calculate the avg_loss and mean_edit_distance and retrieve the decoded
# batch along with the original batch's labels (Y) of this tower
avg_loss, non_finite_files = calculate_mean_edit_distance_and_loss(iterator, dropout_rates, reuse=i > 0)
# Allow for variables to be re-used by the next tower
tfv1.get_variable_scope().reuse_variables()
# Retain tower's avg losses
tower_avg_losses.append(avg_loss)
# Compute gradients for model parameters using tower's mini-batch
gradients = optimizer.compute_gradients(avg_loss)
# Retain tower's gradients
tower_gradients.append(gradients)
tower_non_finite_files.append(non_finite_files)
avg_loss_across_towers = tf.reduce_mean(input_tensor=tower_avg_losses, axis=0)
tfv1.summary.scalar(name='step_loss', tensor=avg_loss_across_towers, collections=['step_summaries'])
all_non_finite_files = tf.concat(tower_non_finite_files, axis=0)
# Return gradients and the average loss
return tower_gradients, avg_loss_across_towers, all_non_finite_files
def average_gradients(tower_gradients):
r'''
A routine for computing each variable's average of the gradients obtained from the GPUs.
Note also that this code acts as a synchronization point as it requires all
GPUs to be finished with their mini-batch before it can run to completion.
'''
# List of average gradients to return to the caller
average_grads = []
# Run this on cpu_device to conserve GPU memory
with tf.device(Config.cpu_device):
# Loop over gradient/variable pairs from all towers
for grad_and_vars in zip(*tower_gradients):
# Introduce grads to store the gradients for the current variable
grads = []
# Loop over the gradients for the current variable
for g, _ in grad_and_vars:
# Add 0 dimension to the gradients to represent the tower.
expanded_g = tf.expand_dims(g, 0)
# Append on a 'tower' dimension which we will average over below.
grads.append(expanded_g)
# Average over the 'tower' dimension
grad = tf.concat(grads, 0)
grad = tf.reduce_mean(input_tensor=grad, axis=0)
# Create a gradient/variable tuple for the current variable with its average gradient
grad_and_var = (grad, grad_and_vars[0][1])
# Add the current tuple to average_grads
average_grads.append(grad_and_var)
# Return result to caller
return average_grads
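# Illustrative sketch (not part of training): numerically, averaging tower
# gradients is just a mean over the added 'tower' dimension. Toy numbers only.
def _average_gradients_sketch():
    tower_0 = np.array([0.2, -0.4])        # one variable's gradient on GPU 0
    tower_1 = np.array([0.4, 0.0])         # same variable's gradient on GPU 1
    stacked = np.stack([tower_0, tower_1], axis=0)   # expand_dims + concat above
    avg = stacked.mean(axis=0)             # reduce_mean over the tower axis
    assert np.allclose(avg, [0.3, -0.2])
    return avg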
# Logging
# =======
def log_variable(variable, gradient=None):
r'''
We introduce a function for logging a tensor variable's current state.
It logs scalar values for the mean, standard deviation, minimum and maximum.
Furthermore it logs a histogram of its state and (if given) of an optimization gradient.
'''
name = variable.name.replace(':', '_')
mean = tf.reduce_mean(input_tensor=variable)
tfv1.summary.scalar(name='%s/mean' % name, tensor=mean)
tfv1.summary.scalar(name='%s/sttdev' % name, tensor=tf.sqrt(tf.reduce_mean(input_tensor=tf.square(variable - mean))))
tfv1.summary.scalar(name='%s/max' % name, tensor=tf.reduce_max(input_tensor=variable))
tfv1.summary.scalar(name='%s/min' % name, tensor=tf.reduce_min(input_tensor=variable))
tfv1.summary.histogram(name=name, values=variable)
if gradient is not None:
if isinstance(gradient, tf.IndexedSlices):
grad_values = gradient.values
else:
grad_values = gradient
if grad_values is not None:
tfv1.summary.histogram(name='%s/gradients' % name, values=grad_values)
def log_grads_and_vars(grads_and_vars):
r'''
Let's also introduce a helper function for logging collections of gradient/variable tuples.
'''
for gradient, variable in grads_and_vars:
log_variable(variable, gradient=gradient)
def train():
do_cache_dataset = True
# pylint: disable=too-many-boolean-expressions
if (FLAGS.data_aug_features_multiplicative > 0 or
FLAGS.data_aug_features_additive > 0 or
FLAGS.augmentation_spec_dropout_keeprate < 1 or
FLAGS.augmentation_freq_and_time_masking or
FLAGS.augmentation_pitch_and_tempo_scaling or
FLAGS.augmentation_speed_up_std > 0 or
FLAGS.augmentation_sparse_warp):
do_cache_dataset = False
exception_box = ExceptionBox()
# Create training and validation datasets
train_set = create_dataset(FLAGS.train_files.split(','),
batch_size=FLAGS.train_batch_size,
enable_cache=FLAGS.feature_cache and do_cache_dataset,
cache_path=FLAGS.feature_cache,
train_phase=True,
exception_box=exception_box,
process_ahead=len(Config.available_devices) * FLAGS.train_batch_size * 2,
buffering=FLAGS.read_buffer)
iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
tfv1.data.get_output_shapes(train_set),
output_classes=tfv1.data.get_output_classes(train_set))
# Make initialization ops for switching between the two sets
train_init_op = iterator.make_initializer(train_set)
if FLAGS.dev_files:
dev_sources = FLAGS.dev_files.split(',')
dev_sets = [create_dataset([source],
batch_size=FLAGS.dev_batch_size,
train_phase=False,
exception_box=exception_box,
process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2,
buffering=FLAGS.read_buffer) for source in dev_sources]
dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]
# Dropout
dropout_rates = [tfv1.placeholder(tf.float32, name='dropout_{}'.format(i)) for i in range(6)]
dropout_feed_dict = {
dropout_rates[0]: FLAGS.dropout_rate,
dropout_rates[1]: FLAGS.dropout_rate2,
dropout_rates[2]: FLAGS.dropout_rate3,
dropout_rates[3]: FLAGS.dropout_rate4,
dropout_rates[4]: FLAGS.dropout_rate5,
dropout_rates[5]: FLAGS.dropout_rate6,
}
no_dropout_feed_dict = {
rate: 0. for rate in dropout_rates
}
# Building the graph
learning_rate_var = tfv1.get_variable('learning_rate', initializer=FLAGS.learning_rate, trainable=False)
reduce_learning_rate_op = learning_rate_var.assign(tf.multiply(learning_rate_var, FLAGS.plateau_reduction))
optimizer = create_optimizer(learning_rate_var)
# Enable mixed precision training
if FLAGS.automatic_mixed_precision:
log_info('Enabling automatic mixed precision training.')
optimizer = tfv1.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)
gradients, loss, non_finite_files = get_tower_results(iterator, optimizer, dropout_rates)
# Average tower gradients across GPUs
avg_tower_gradients = average_gradients(gradients)
log_grads_and_vars(avg_tower_gradients)
# global_step is automagically incremented by the optimizer
global_step = tfv1.train.get_or_create_global_step()
apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients, global_step=global_step)
# Summaries
step_summaries_op = tfv1.summary.merge_all('step_summaries')
step_summary_writers = {
'train': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'train'), max_queue=120),
'dev': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'dev'), max_queue=120)
}
# Checkpointing
checkpoint_saver = tfv1.train.Saver(max_to_keep=FLAGS.max_to_keep)
checkpoint_path = os.path.join(FLAGS.save_checkpoint_dir, 'train')
best_dev_saver = tfv1.train.Saver(max_to_keep=1)
best_dev_path = os.path.join(FLAGS.save_checkpoint_dir, 'best_dev')
# Save flags next to checkpoints
os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True)
flags_file = os.path.join(FLAGS.save_checkpoint_dir, 'flags.txt')
with open(flags_file, 'w') as fout:
fout.write(FLAGS.flags_into_string())
with tfv1.Session(config=Config.session_config) as session:
log_debug('Session opened.')
# Prevent further graph changes
tfv1.get_default_graph().finalize()
# Load checkpoint or initialize variables
if FLAGS.load == 'auto':
method_order = ['best', 'last', 'init']
else:
method_order = [FLAGS.load]
load_or_init_graph(session, method_order)
def run_set(set_name, epoch, init_op, dataset=None):
is_train = set_name == 'train'
train_op = apply_gradient_op if is_train else []
feed_dict = dropout_feed_dict if is_train else no_dropout_feed_dict
total_loss = 0.0
step_count = 0
step_summary_writer = step_summary_writers.get(set_name)
checkpoint_time = time.time()
# Setup progress bar
class LossWidget(progressbar.widgets.FormatLabel):
def __init__(self):
progressbar.widgets.FormatLabel.__init__(self, format='Loss: %(mean_loss)f')
def __call__(self, progress, data, **kwargs):
data['mean_loss'] = total_loss / step_count if step_count else 0.0
return progressbar.widgets.FormatLabel.__call__(self, progress, data, **kwargs)
prefix = 'Epoch {} | {:>10}'.format(epoch, 'Training' if is_train else 'Validation')
widgets = [' | ', progressbar.widgets.Timer(),
' | Steps: ', progressbar.widgets.Counter(),
' | ', LossWidget()]
suffix = ' | Dataset: {}'.format(dataset) if dataset else None
pbar = create_progressbar(prefix=prefix, widgets=widgets, suffix=suffix).start()
# Initialize iterator to the appropriate dataset
session.run(init_op)
# Batch loop
while True:
try:
_, current_step, batch_loss, problem_files, step_summary = \
session.run([train_op, global_step, loss, non_finite_files, step_summaries_op],
feed_dict=feed_dict)
exception_box.raise_if_set()
except tf.errors.InvalidArgumentError as err:
if FLAGS.augmentation_sparse_warp:
log_info("Ignoring sparse warp error: {}".format(err))
continue
raise
except tf.errors.OutOfRangeError:
exception_box.raise_if_set()
break
if problem_files.size > 0:
problem_files = [f.decode('utf8') for f in problem_files[..., 0]]
log_error('The following files caused an infinite (or NaN) '
'loss: {}'.format(','.join(problem_files)))
total_loss += batch_loss
step_count += 1
pbar.update(step_count)
step_summary_writer.add_summary(step_summary, current_step)
if is_train and FLAGS.checkpoint_secs > 0 and time.time() - checkpoint_time > FLAGS.checkpoint_secs:
checkpoint_saver.save(session, checkpoint_path, global_step=current_step)
checkpoint_time = time.time()
pbar.finish()
mean_loss = total_loss / step_count if step_count > 0 else 0.0
return mean_loss, step_count
log_info('STARTING Optimization')
train_start_time = datetime.utcnow()
best_dev_loss = float('inf')
dev_losses = []
epochs_without_improvement = 0
try:
for epoch in range(FLAGS.epochs):
# Training
log_progress('Training epoch %d...' % epoch)
train_loss, _ = run_set('train', epoch, train_init_op)
log_progress('Finished training epoch %d - loss: %f' % (epoch, train_loss))
checkpoint_saver.save(session, checkpoint_path, global_step=global_step)
if FLAGS.dev_files:
# Validation
dev_loss = 0.0
total_steps = 0
for source, init_op in zip(dev_sources, dev_init_ops):
log_progress('Validating epoch %d on %s...' % (epoch, source))
set_loss, steps = run_set('dev', epoch, init_op, dataset=source)
dev_loss += set_loss * steps
total_steps += steps
log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, source, set_loss))
dev_loss = dev_loss / total_steps
dev_losses.append(dev_loss)
# Count epochs without an improvement for early stopping and reduction of learning rate on a plateau
# the improvement has to be greater than FLAGS.es_min_delta
if dev_loss > best_dev_loss - FLAGS.es_min_delta:
epochs_without_improvement += 1
else:
epochs_without_improvement = 0
# Save new best model
if dev_loss < best_dev_loss:
best_dev_loss = dev_loss
save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename='best_dev_checkpoint')
log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path))
# Early stopping
if FLAGS.early_stop and epochs_without_improvement == FLAGS.es_epochs:
log_info('Early stop triggered as the loss did not improve the last {} epochs'.format(
epochs_without_improvement))
break
# Reduce learning rate on plateau
if (FLAGS.reduce_lr_on_plateau and
epochs_without_improvement % FLAGS.plateau_epochs == 0 and epochs_without_improvement > 0):
# If the learning rate was reduced and there is still no improvement
# wait FLAGS.plateau_epochs before the learning rate is reduced again
session.run(reduce_learning_rate_op)
current_learning_rate = learning_rate_var.eval()
log_info('Encountered a plateau, reducing learning rate to {}'.format(
current_learning_rate))
except KeyboardInterrupt:
pass
log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time))
log_debug('Session closed.')
def test():
samples = evaluate(FLAGS.test_files.split(','), create_model)
if FLAGS.test_output_file:
# Save decoded tuples as JSON, converting NumPy floats to Python floats
json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float)
def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
batch_size = batch_size if batch_size > 0 else None
# Create feature computation graph
input_samples = tfv1.placeholder(tf.float32, [Config.audio_window_samples], 'input_samples')
samples = tf.expand_dims(input_samples, -1)
mfccs, _ = samples_to_mfccs(samples, FLAGS.audio_sample_rate)
mfccs = tf.identity(mfccs, name='mfccs')
# Input tensor will be of shape [batch_size, n_steps, 2*n_context+1, n_input]
# This shape is read by the native_client in DS_CreateModel to know the
# value of n_steps, n_context and n_input. Make sure you update the code
# there if this shape is changed.
input_tensor = tfv1.placeholder(tf.float32, [batch_size, n_steps if n_steps > 0 else None, 2 * Config.n_context + 1, Config.n_input], name='input_node')
seq_length = tfv1.placeholder(tf.int32, [batch_size], name='input_lengths')
if batch_size <= 0:
# no state management since n_step is expected to be dynamic too (see below)
previous_state = None
else:
previous_state_c = tfv1.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_c')
previous_state_h = tfv1.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_h')
previous_state = tf.nn.rnn_cell.LSTMStateTuple(previous_state_c, previous_state_h)
# One rate per layer
no_dropout = [None] * 6
if tflite:
rnn_impl = rnn_impl_static_rnn
else:
rnn_impl = rnn_impl_lstmblockfusedcell
logits, layers = create_model(batch_x=input_tensor,
batch_size=batch_size,
seq_length=seq_length if not FLAGS.export_tflite else None,
dropout=no_dropout,
previous_state=previous_state,
overlap=False,
rnn_impl=rnn_impl)
# TF Lite runtime will check that input dimensions are 1, 2 or 4
# by default we get 3, the middle one being batch_size which is forced to
# one on inference graph, so remove that dimension
if tflite:
logits = tf.squeeze(logits, [1])
# Apply softmax for CTC decoder
logits = tf.nn.softmax(logits, name='logits')
if batch_size <= 0:
if tflite:
raise NotImplementedError('dynamic batch_size does not support tflite nor streaming')
if n_steps > 0:
raise NotImplementedError('dynamic batch_size expect n_steps to be dynamic too')
return (
{
'input': input_tensor,
'input_lengths': seq_length,
},
{
'outputs': logits,
},
layers
)
new_state_c, new_state_h = layers['rnn_output_state']
new_state_c = tf.identity(new_state_c, name='new_state_c')
new_state_h = tf.identity(new_state_h, name='new_state_h')
inputs = {
'input': input_tensor,
'previous_state_c': previous_state_c,
'previous_state_h': previous_state_h,
'input_samples': input_samples,
}
if not FLAGS.export_tflite:
inputs['input_lengths'] = seq_length
outputs = {
'outputs': logits,
'new_state_c': new_state_c,
'new_state_h': new_state_h,
'mfccs': mfccs,
}
return inputs, outputs, layers
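# Illustrative sketch (not part of the export path): how a client could stream
# audio with the state tensors above, feeding each step's output state back in
# as the next step's previous state. `session`, `chunks`, and the `inputs`/
# `outputs` dicts are assumed to come from create_inference_graph plus a
# checkpoint-restored session; chunk shapes must match the placeholders, and
# 'input_lengths' is only present when not exporting to TF Lite.
def _streaming_sketch(session, chunks, inputs, outputs):
    state_c = np.zeros([1, Config.n_cell_dim])
    state_h = np.zeros([1, Config.n_cell_dim])
    probs = []
    for chunk, chunk_len in chunks:
        step_probs, state_c, state_h = session.run(
            [outputs['outputs'], outputs['new_state_c'], outputs['new_state_h']],
            feed_dict={inputs['input']: chunk,
                       inputs['input_lengths']: chunk_len,
                       inputs['previous_state_c']: state_c,
                       inputs['previous_state_h']: state_h})
        probs.append(step_probs)  # decode the concatenated probabilities at the end
    return probs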
def file_relative_read(fname):
return open(os.path.join(os.path.dirname(__file__), fname)).read()
def export():
r'''
Restores the trained variables into a simpler graph that will be exported for serving.
'''
log_info('Exporting the model...')
inputs, outputs, _ = create_inference_graph(batch_size=FLAGS.export_batch_size, n_steps=FLAGS.n_steps, tflite=FLAGS.export_tflite)
graph_version = int(file_relative_read('GRAPH_VERSION').strip())
assert graph_version > 0
outputs['metadata_version'] = tf.constant([graph_version], name='metadata_version')
outputs['metadata_sample_rate'] = tf.constant([FLAGS.audio_sample_rate], name='metadata_sample_rate')
outputs['metadata_feature_win_len'] = tf.constant([FLAGS.feature_win_len], name='metadata_feature_win_len')
outputs['metadata_feature_win_step'] = tf.constant([FLAGS.feature_win_step], name='metadata_feature_win_step')
outputs['metadata_beam_width'] = tf.constant([FLAGS.export_beam_width], name='metadata_beam_width')
outputs['metadata_alphabet'] = tf.constant([Config.alphabet.serialize()], name='metadata_alphabet')
if FLAGS.export_language:
outputs['metadata_language'] = tf.constant([FLAGS.export_language.encode('utf-8')], name='metadata_language')
# Prevent further graph changes
tfv1.get_default_graph().finalize()
output_names_tensors = [tensor.op.name for tensor in outputs.values() if isinstance(tensor, tf.Tensor)]
output_names_ops = [op.name for op in outputs.values() if isinstance(op, tf.Operation)]
output_names = output_names_tensors + output_names_ops
with tf.Session() as session:
# Restore variables from checkpoint
if FLAGS.load == 'auto':
method_order = ['best', 'last']
else:
method_order = [FLAGS.load]
load_or_init_graph(session, method_order)
output_filename = FLAGS.export_file_name + '.pb'
if FLAGS.remove_export:
if os.path.isdir(FLAGS.export_dir):
log_info('Removing old export')
shutil.rmtree(FLAGS.export_dir)
output_graph_path = os.path.join(FLAGS.export_dir, output_filename)
if not os.path.isdir(FLAGS.export_dir):
os.makedirs(FLAGS.export_dir)
frozen_graph = tfv1.graph_util.convert_variables_to_constants(
sess=session,
input_graph_def=tfv1.get_default_graph().as_graph_def(),
output_node_names=output_names)
frozen_graph = tfv1.graph_util.extract_sub_graph(
graph_def=frozen_graph,
dest_nodes=output_names)
if not FLAGS.export_tflite:
with open(output_graph_path, 'wb') as fout:
fout.write(frozen_graph.SerializeToString())
else:
output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))
converter = tf.lite.TFLiteConverter(frozen_graph, input_tensors=inputs.values(), output_tensors=outputs.values())
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# AudioSpectrogram and Mfcc ops are custom but have built-in kernels in TFLite
converter.allow_custom_ops = True
tflite_model = converter.convert()
with open(output_tflite_path, 'wb') as fout:
fout.write(tflite_model)
log_info('Models exported at %s' % (FLAGS.export_dir))
metadata_fname = os.path.join(FLAGS.export_dir, '{}_{}_{}.md'.format(
FLAGS.export_author_id,
FLAGS.export_model_name,
FLAGS.export_model_version))
model_runtime = 'tflite' if FLAGS.export_tflite else 'tensorflow'
with open(metadata_fname, 'w') as f:
f.write('---\n')
f.write('author: {}\n'.format(FLAGS.export_author_id))
f.write('model_name: {}\n'.format(FLAGS.export_model_name))
f.write('model_version: {}\n'.format(FLAGS.export_model_version))
f.write('contact_info: {}\n'.format(FLAGS.export_contact_info))
f.write('license: {}\n'.format(FLAGS.export_license))
f.write('language: {}\n'.format(FLAGS.export_language))
f.write('runtime: {}\n'.format(model_runtime))
f.write('min_ds_version: {}\n'.format(FLAGS.export_min_ds_version))
f.write('max_ds_version: {}\n'.format(FLAGS.export_max_ds_version))
f.write('acoustic_model_url: <replace this with a publicly available URL of the acoustic model>\n')
f.write('scorer_url: <replace this with a publicly available URL of the scorer, if present>\n')
f.write('---\n')
f.write('{}\n'.format(FLAGS.export_description))
log_info('Model metadata file saved to {}. Before submitting the exported model for publishing make sure all information in the metadata file is correct, and complete the URL fields.'.format(metadata_fname))
def package_zip():
# --export_dir path/to/export/LANG_CODE/ => path/to/export/LANG_CODE.zip
export_dir = os.path.join(os.path.abspath(FLAGS.export_dir), '') # Force ending '/'
zip_filename = os.path.dirname(export_dir)
shutil.copy(FLAGS.scorer_path, export_dir)
archive = shutil.make_archive(zip_filename, 'zip', export_dir)
log_info('Exported packaged model {}'.format(archive))
def do_single_file_inference(input_file_path):
with tfv1.Session(config=Config.session_config) as session:
inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=-1)
# Restore variables from training checkpoint
if FLAGS.load == 'auto':
method_order = ['best', 'last']
else:
method_order = [FLAGS.load]
load_or_init_graph(session, method_order)
features, features_len = audiofile_to_features(input_file_path)
previous_state_c = np.zeros([1, Config.n_cell_dim])
previous_state_h = np.zeros([1, Config.n_cell_dim])
# Add batch dimension
features = tf.expand_dims(features, 0)
features_len = tf.expand_dims(features_len, 0)
# Evaluate
features = create_overlapping_windows(features).eval(session=session)
features_len = features_len.eval(session=session)
logits = outputs['outputs'].eval(feed_dict={
inputs['input']: features,
inputs['input_lengths']: features_len,
inputs['previous_state_c']: previous_state_c,
inputs['previous_state_h']: previous_state_h,
}, session=session)
logits = np.squeeze(logits)
if FLAGS.scorer_path:
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
FLAGS.scorer_path, Config.alphabet)
else:
scorer = None
decoded = ctc_beam_search_decoder(logits, Config.alphabet, FLAGS.beam_width,
scorer=scorer, cutoff_prob=FLAGS.cutoff_prob,
cutoff_top_n=FLAGS.cutoff_top_n)
# Print highest probability result
print(decoded[0][1])
def main(_):
initialize_globals()
if FLAGS.train_files:
tfv1.reset_default_graph()
tfv1.set_random_seed(FLAGS.random_seed)
train()
if FLAGS.test_files:
tfv1.reset_default_graph()
test()
if FLAGS.export_dir and not FLAGS.export_zip:
tfv1.reset_default_graph()
export()
if FLAGS.export_zip:
tfv1.reset_default_graph()
FLAGS.export_tflite = True
if os.listdir(FLAGS.export_dir):
log_error('Directory {} is not empty, please fix this.'.format(FLAGS.export_dir))
sys.exit(1)
export()
package_zip()
if FLAGS.one_shot_infer:
tfv1.reset_default_graph()
do_single_file_inference(FLAGS.one_shot_infer)
def run_script():
create_flags()
absl.app.run(main)
if __name__ == '__main__':
run_script()


@@ -5,7 +5,7 @@ import tempfile
 import collections
 import numpy as np

-from util.helpers import LimitingPool
+from .helpers import LimitingPool

 DEFAULT_RATE = 16000
 DEFAULT_CHANNELS = 1


@@ -2,8 +2,8 @@ import sys
 import tensorflow as tf
 import tensorflow.compat.v1 as tfv1

-from util.flags import FLAGS
-from util.logging import log_info, log_error, log_warn
+from .flags import FLAGS
+from .logging import log_info, log_error, log_warn

 def _load_checkpoint(session, checkpoint_path):


@@ -8,11 +8,11 @@ import tensorflow.compat.v1 as tfv1
 from attrdict import AttrDict
 from xdg import BaseDirectory as xdg

-from util.flags import FLAGS
-from util.gpu import get_available_gpus
-from util.logging import log_error
-from util.text import Alphabet, UTF8Alphabet
-from util.helpers import parse_file_size
+from .flags import FLAGS
+from .gpu import get_available_gpus
+from .logging import log_error
+from .text import Alphabet, UTF8Alphabet
+from .helpers import parse_file_size

 class ConfigSingleton:
     _config = None


@@ -7,8 +7,8 @@ import numpy as np
 from attrdict import AttrDict

-from util.flags import FLAGS
-from util.text import levenshtein
+from .flags import FLAGS
+from .text import levenshtein

 def pmap(fun, iterable):


@@ -8,13 +8,13 @@ import tensorflow as tf
 from tensorflow.python.ops import gen_audio_ops as contrib_audio

-from util.config import Config
-from util.text import text_to_char_array
-from util.flags import FLAGS
-from util.spectrogram_augmentations import augment_freq_time_mask, augment_dropout, augment_pitch_and_tempo, augment_speed_up, augment_sparse_warp
-from util.audio import change_audio_types, read_frames_from_file, vad_split, pcm_to_np, DEFAULT_FORMAT, AUDIO_TYPE_NP
-from util.sample_collections import samples_from_files
-from util.helpers import remember_exception, MEGABYTE
+from .config import Config
+from .text import text_to_char_array
+from .flags import FLAGS
+from .spectrogram_augmentations import augment_freq_time_mask, augment_dropout, augment_pitch_and_tempo, augment_speed_up, augment_sparse_warp
+from .audio import change_audio_types, read_frames_from_file, vad_split, pcm_to_np, DEFAULT_FORMAT, AUDIO_TYPE_NP
+from .sample_collections import samples_from_files
+from .helpers import remember_exception, MEGABYTE

 def samples_to_mfccs(samples, sample_rate, train_phase=False, sample_id=None):


@@ -4,7 +4,7 @@ import os
 import re
 import sys

-from util.helpers import secs_to_hours
+from .helpers import secs_to_hours
 from collections import Counter

 def get_counter():


@@ -3,7 +3,7 @@ from __future__ import print_function
 import progressbar
 import sys

-from util.flags import FLAGS
+from .flags import FLAGS

 # Logging functions


@@ -5,8 +5,8 @@ import json
 from pathlib import Path
 from functools import partial

-from util.helpers import MEGABYTE, GIGABYTE, Interleaved
-from util.audio import Sample, DEFAULT_FORMAT, AUDIO_TYPE_WAV, AUDIO_TYPE_OPUS, SERIALIZABLE_AUDIO_TYPES
+from .helpers import MEGABYTE, GIGABYTE, Interleaved
+from .audio import Sample, DEFAULT_FORMAT, AUDIO_TYPE_WAV, AUDIO_TYPE_OPUS, SERIALIZABLE_AUDIO_TYPES

 BIG_ENDIAN = 'big'
 INT_SIZE = 4


@@ -1,6 +1,7 @@
 import tensorflow as tf
 import tensorflow.compat.v1 as tfv1

-from util.sparse_image_warp import sparse_image_warp
+from .sparse_image_warp import sparse_image_warp
+

 def augment_freq_time_mask(spectrogram,
                            frequency_masking_para=30,


@@ -0,0 +1,167 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import print_function, absolute_import, division
import argparse
import errno
import gzip
import os
import platform
import six.moves.urllib as urllib
import stat
import subprocess
import sys
from pkg_resources import parse_version
DEFAULT_SCHEMES = {
'deepspeech': 'https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.deepspeech.native_client.%(branch_name)s.%(arch_string)s/artifacts/public/%(artifact_name)s',
'tensorflow': 'https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.tensorflow.pip.%(branch_name)s.%(arch_string)s/artifacts/public/%(artifact_name)s'
}
TASKCLUSTER_SCHEME = os.getenv('TASKCLUSTER_SCHEME', DEFAULT_SCHEMES['deepspeech'])
def get_tc_url(arch_string, artifact_name='native_client.tar.xz', branch_name='master'):
assert arch_string is not None
assert artifact_name is not None
assert artifact_name
assert branch_name is not None
assert branch_name
return TASKCLUSTER_SCHEME % {'arch_string': arch_string, 'artifact_name': artifact_name, 'branch_name': branch_name}
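# For example (illustrative expansion only, derived from the default scheme
# and default arguments above): get_tc_url('cpu') yields
# https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.deepspeech.native_client.master.cpu/artifacts/public/native_client.tar.xz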
def maybe_download_tc(target_dir, tc_url, progress=True):
def report_progress(count, block_size, total_size):
percent = (count * block_size * 100) // total_size
sys.stdout.write("\rDownloading: %d%%" % percent)
sys.stdout.flush()
if percent >= 100:
print('\n')
assert target_dir is not None
target_dir = os.path.abspath(target_dir)
try:
os.makedirs(target_dir)
except OSError as e:
if e.errno != errno.EEXIST:
raise e
assert os.path.isdir(os.path.dirname(target_dir))
tc_filename = os.path.basename(tc_url)
target_file = os.path.join(target_dir, tc_filename)
is_gzip = False
if not os.path.isfile(target_file):
print('Downloading %s ...' % tc_url)
_, headers = urllib.request.urlretrieve(tc_url, target_file, reporthook=(report_progress if progress else None))
is_gzip = headers.get('Content-Encoding') == 'gzip'
else:
print('File already exists: %s' % target_file)
if is_gzip:
with open(target_file, "r+b") as frw:
decompressed = gzip.decompress(frw.read())
frw.seek(0)
frw.write(decompressed)
frw.truncate()
return target_file
def maybe_download_tc_bin(**kwargs):
final_file = maybe_download_tc(kwargs['target_dir'], kwargs['tc_url'], kwargs['progress'])
final_stat = os.stat(final_file)
os.chmod(final_file, final_stat.st_mode | stat.S_IEXEC)
def read(fname):
return open(os.path.join(os.path.dirname(__file__), fname)).read()
def main():
parser = argparse.ArgumentParser(description='Tooling to ease downloading of components from TaskCluster.')
parser.add_argument('--target', required=False,
help='Where to put the native client binary files')
parser.add_argument('--arch', required=False,
help='Which architecture to download binaries for. "arm" for ARM 7 (32-bit), "arm64" for ARM64, "gpu" for CUDA enabled x86_64 binaries, "cpu" for CPU-only x86_64 binaries, "osx" for CPU-only x86_64 OSX binaries. Optional ("cpu" by default)')
parser.add_argument('--artifact', required=False,
default='native_client.tar.xz',
help='Name of the artifact to download. Defaults to "native_client.tar.xz"')
parser.add_argument('--source', required=False, default=None,
help='Name of the TaskCluster scheme to use.')
parser.add_argument('--branch', required=False,
help='Branch name to use. Defaulting to current content of VERSION file.')
parser.add_argument('--decoder', action='store_true',
help='Get URL to ds_ctcdecoder Python package.')
args = parser.parse_args()
if not args.target and not args.decoder:
print('Pass either --target or --decoder.')
sys.exit(1)
is_arm = 'arm' in platform.machine()
is_mac = 'darwin' in sys.platform
is_64bit = sys.maxsize > (2**31 - 1)
is_ucs2 = sys.maxunicode < 0x10ffff
if not args.arch:
if is_arm:
args.arch = 'arm64' if is_64bit else 'arm'
elif is_mac:
args.arch = 'osx'
else:
args.arch = 'cpu'
if not args.branch:
version_string = read('../VERSION').strip()
ds_version = parse_version(version_string)
args.branch = "v{}".format(version_string)
else:
ds_version = parse_version(args.branch)
if args.decoder:
plat = platform.system().lower()
arch = platform.machine()
if plat == 'linux' and arch == 'x86_64':
plat = 'manylinux1'
if plat == 'darwin':
plat = 'macosx_10_10'
m_or_mu = 'mu' if is_ucs2 else 'm'
pyver = ''.join(map(str, sys.version_info[0:2]))
artifact = "ds_ctcdecoder-{ds_version}-cp{pyver}-cp{pyver}{m_or_mu}-{platform}_{arch}.whl".format(
ds_version=ds_version,
pyver=pyver,
m_or_mu=m_or_mu,
platform=plat,
arch=arch
)
ctc_arch = args.arch + '-ctc'
print(get_tc_url(ctc_arch, artifact, args.branch))
sys.exit(0)
if args.source is not None:
if args.source in DEFAULT_SCHEMES:
global TASKCLUSTER_SCHEME
TASKCLUSTER_SCHEME = DEFAULT_SCHEMES[args.source]
else:
print('No such scheme: %s' % args.source)
sys.exit(1)
maybe_download_tc(target_dir=args.target, tc_url=get_tc_url(args.arch, args.artifact, args.branch))
if args.artifact == "convert_graphdef_memmapped_format":
convert_graph_file = os.path.join(args.target, args.artifact)
final_stat = os.stat(convert_graph_file)
os.chmod(convert_graph_file, final_stat.st_mode | stat.S_IEXEC)
if '.tar.' in args.artifact:
subprocess.check_call(['tar', 'xvf', os.path.join(args.target, args.artifact), '-C', args.target])
if __name__ == '__main__':
main()


@@ -12,13 +12,13 @@ tflogging.set_verbosity(tflogging.ERROR)
 import logging
 logging.getLogger('sox').setLevel(logging.ERROR)

-from multiprocessing import Process, cpu_count
+from deepspeech_training.util.audio import AudioFile
+from deepspeech_training.util.config import Config, initialize_globals
+from deepspeech_training.util.feeding import split_audio_file
+from deepspeech_training.util.flags import create_flags, FLAGS
+from deepspeech_training.util.logging import log_error, log_info, log_progress, create_progressbar
 from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
-from util.config import Config, initialize_globals
-from util.audio import AudioFile
-from util.feeding import split_audio_file
-from util.flags import create_flags, FLAGS
-from util.logging import log_error, log_info, log_progress, create_progressbar
+from multiprocessing import Process, cpu_count

 def fail(message, code=1):
@@ -27,8 +27,8 @@ def fail(message, code=1):

 def transcribe_file(audio_path, tlog_path):
-    from DeepSpeech import create_model  # pylint: disable=cyclic-import,import-outside-toplevel
-    from util.checkpoints import load_or_init_graph
+    from deepspeech_training.train import create_model  # pylint: disable=cyclic-import,import-outside-toplevel
+    from deepspeech_training.util.checkpoints import load_or_init_graph
     initialize_globals()
     scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet)
     try:


@@ -1,168 +1,12 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
-from __future__ import print_function, absolute_import, division
+from __future__ import absolute_import, division, print_function
import argparse
import platform
import subprocess
import sys
import os
import errno
import stat
import gzip
import six.moves.urllib as urllib
from pkg_resources import parse_version
DEFAULT_SCHEMES = {
'deepspeech': 'https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.deepspeech.native_client.%(branch_name)s.%(arch_string)s/artifacts/public/%(artifact_name)s',
'tensorflow': 'https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.tensorflow.pip.%(branch_name)s.%(arch_string)s/artifacts/public/%(artifact_name)s'
}
TASKCLUSTER_SCHEME = os.getenv('TASKCLUSTER_SCHEME', DEFAULT_SCHEMES['deepspeech'])
def get_tc_url(arch_string, artifact_name='native_client.tar.xz', branch_name='master'):
assert arch_string is not None
assert artifact_name is not None
assert artifact_name
assert branch_name is not None
assert branch_name
return TASKCLUSTER_SCHEME % { 'arch_string': arch_string, 'artifact_name': artifact_name, 'branch_name': branch_name}
def maybe_download_tc(target_dir, tc_url, progress=True):
def report_progress(count, block_size, total_size):
percent = (count * block_size * 100) // total_size
sys.stdout.write("\rDownloading: %d%%" % percent)
sys.stdout.flush()
if percent >= 100:
print('\n')
assert target_dir is not None
target_dir = os.path.abspath(target_dir)
try:
os.makedirs(target_dir)
except OSError as e:
if e.errno != errno.EEXIST:
raise e
assert os.path.isdir(os.path.dirname(target_dir))
tc_filename = os.path.basename(tc_url)
target_file = os.path.join(target_dir, tc_filename)
is_gzip = False
if not os.path.isfile(target_file):
print('Downloading %s ...' % tc_url)
_, headers = urllib.request.urlretrieve(tc_url, target_file, reporthook=(report_progress if progress else None))
is_gzip = headers.get('Content-Encoding') == 'gzip'
else:
print('File already exists: %s' % target_file)
if is_gzip:
with open(target_file, "r+b") as frw:
decompressed = gzip.decompress(frw.read())
frw.seek(0)
frw.write(decompressed)
frw.truncate()
return target_file
def maybe_download_tc_bin(**kwargs):
final_file = maybe_download_tc(kwargs['target_dir'], kwargs['tc_url'], kwargs['progress'])
final_stat = os.stat(final_file)
os.chmod(final_file, final_stat.st_mode | stat.S_IEXEC)
def read(fname):
return open(os.path.join(os.path.dirname(__file__), fname)).read()
def main():
parser = argparse.ArgumentParser(description='Tooling to ease downloading of components from TaskCluster.')
parser.add_argument('--target', required=False,
help='Where to put the native client binary files')
parser.add_argument('--arch', required=False,
help='Which architecture to download binaries for. "arm" for ARM 7 (32-bit), "arm64" for ARM64, "gpu" for CUDA enabled x86_64 binaries, "cpu" for CPU-only x86_64 binaries, "osx" for CPU-only x86_64 OSX binaries. Optional ("cpu" by default)')
parser.add_argument('--artifact', required=False,
default='native_client.tar.xz',
help='Name of the artifact to download. Defaults to "native_client.tar.xz"')
parser.add_argument('--source', required=False, default=None,
help='Name of the TaskCluster scheme to use.')
parser.add_argument('--branch', required=False,
help='Branch name to use. Defaulting to current content of VERSION file.')
parser.add_argument('--decoder', action='store_true',
help='Get URL to ds_ctcdecoder Python package.')
args = parser.parse_args()
if not args.target and not args.decoder:
print('Pass either --target or --decoder.')
exit(1)
is_arm = 'arm' in platform.machine()
is_mac = 'darwin' in sys.platform
is_64bit = sys.maxsize > (2**31 - 1)
is_ucs2 = sys.maxunicode < 0x10ffff
if not args.arch:
if is_arm:
args.arch = 'arm64' if is_64bit else 'arm'
elif is_mac:
args.arch = 'osx'
else:
args.arch = 'cpu'
if not args.branch:
version_string = read('../VERSION').strip()
ds_version = parse_version(version_string)
args.branch = "v{}".format(version_string)
else:
ds_version = parse_version(args.branch)
if args.decoder:
plat = platform.system().lower()
arch = platform.machine()
if plat == 'linux' and arch == 'x86_64':
plat = 'manylinux1'
if plat == 'darwin':
plat = 'macosx_10_10'
m_or_mu = 'mu' if is_ucs2 else 'm'
pyver = ''.join(map(str, sys.version_info[0:2]))
artifact = "ds_ctcdecoder-{ds_version}-cp{pyver}-cp{pyver}{m_or_mu}-{platform}_{arch}.whl".format(
ds_version=ds_version,
pyver=pyver,
m_or_mu=m_or_mu,
platform=plat,
arch=arch
)
ctc_arch = args.arch + '-ctc'
print(get_tc_url(ctc_arch, artifact, args.branch))
exit(0)
if args.source is not None:
if args.source in DEFAULT_SCHEMES:
global TASKCLUSTER_SCHEME
TASKCLUSTER_SCHEME = DEFAULT_SCHEMES[args.source]
else:
print('No such scheme: %s' % args.source)
exit(1)
maybe_download_tc(target_dir=args.target, tc_url=get_tc_url(args.arch, args.artifact, args.branch))
if args.artifact == "convert_graphdef_memmapped_format":
convert_graph_file = os.path.join(args.target, args.artifact)
final_stat = os.stat(convert_graph_file)
os.chmod(convert_graph_file, final_stat.st_mode | stat.S_IEXEC)
if '.tar.' in args.artifact:
subprocess.check_call(['tar', 'xvf', os.path.join(args.target, args.artifact), '-C', args.target])
 if __name__ == '__main__':
-    main()
+    try:
+        from deepspeech_training.util import taskcluster as dsu_taskcluster
+    except ImportError:
+        print('Training package is not installed. See training documentation.')
+        raise
+    dsu_taskcluster.main()