зеркало из https://github.com/mozilla/DeepSpeech.git
Package training code to avoid sys.path hacks
This commit is contained in:
Родитель
58bc2f2bb1
Коммит
a05baa35c9
|
@ -9,7 +9,7 @@ python:
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
include:
|
include:
|
||||||
- stage: cardboard linter
|
- name: cardboard linter
|
||||||
install:
|
install:
|
||||||
- pip install --upgrade cardboardlint pylint
|
- pip install --upgrade cardboardlint pylint
|
||||||
script:
|
script:
|
||||||
|
@ -17,9 +17,10 @@ jobs:
|
||||||
- if [ "$TRAVIS_PULL_REQUEST" != "false" ]; then
|
- if [ "$TRAVIS_PULL_REQUEST" != "false" ]; then
|
||||||
cardboardlinter --refspec $TRAVIS_BRANCH -n auto;
|
cardboardlinter --refspec $TRAVIS_BRANCH -n auto;
|
||||||
fi
|
fi
|
||||||
- stage: python unit tests
|
- name: python unit tests
|
||||||
install:
|
install:
|
||||||
- pip install --upgrade -r requirements_tests.txt
|
- pip install --upgrade -r requirements_tests.txt;
|
||||||
|
pip install --upgrade .
|
||||||
script:
|
script:
|
||||||
- if [ "$TRAVIS_PULL_REQUEST" != "false" ]; then
|
- if [ "$TRAVIS_PULL_REQUEST" != "false" ]; then
|
||||||
python -m unittest;
|
python -m unittest;
|
||||||
|
|
937
DeepSpeech.py
937
DeepSpeech.py
|
@ -2,934 +2,11 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
|
|
||||||
# Peek at the raw command line for ``--log_level <value>`` so that the
# TF_CPP_MIN_LOG_LEVEL environment variable is exported before the
# ``import tensorflow`` statement further down is executed. When the flag is
# absent, or is the last argument with no value after it, level '3' is used.
if '--log_level' in sys.argv:
    LOG_LEVEL_INDEX = sys.argv.index('--log_level') + 1
else:
    LOG_LEVEL_INDEX = 0

if 0 < LOG_LEVEL_INDEX < len(sys.argv):
    DESIRED_LOG_LEVEL = sys.argv[LOG_LEVEL_INDEX]
else:
    DESIRED_LOG_LEVEL = '3'

os.environ['TF_CPP_MIN_LOG_LEVEL'] = DESIRED_LOG_LEVEL
|
|
||||||
|
|
||||||
import absl.app
|
|
||||||
import json
|
|
||||||
import numpy as np
|
|
||||||
import progressbar
|
|
||||||
import shutil
|
|
||||||
import tensorflow as tf
|
|
||||||
import tensorflow.compat.v1 as tfv1
|
|
||||||
import time
|
|
||||||
|
|
||||||
# Translate the textual log level picked from the command line into the
# matching TensorFlow Python logging verbosity constant.
# NOTE(review): for a level other than '0'..'3' the lookup yields None, which
# is then handed to set_verbosity() as-is — confirm that this is acceptable.
tfv1.logging.set_verbosity({
    '0': tfv1.logging.DEBUG,
    '1': tfv1.logging.INFO,
    '2': tfv1.logging.WARN,
    '3': tfv1.logging.ERROR
}.get(DESIRED_LOG_LEVEL))
|
|
||||||
|
|
||||||
from datetime import datetime
|
|
||||||
from ds_ctcdecoder import ctc_beam_search_decoder, Scorer
|
|
||||||
from evaluate import evaluate
|
|
||||||
from six.moves import zip, range
|
|
||||||
from util.config import Config, initialize_globals
|
|
||||||
from util.checkpoints import load_or_init_graph
|
|
||||||
from util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features
|
|
||||||
from util.flags import create_flags, FLAGS
|
|
||||||
from util.helpers import check_ctcdecoder_version, ExceptionBox
|
|
||||||
from util.logging import log_info, log_error, log_debug, log_progress, create_progressbar
|
|
||||||
|
|
||||||
check_ctcdecoder_version()
|
|
||||||
|
|
||||||
# Graph Creation
|
|
||||||
# ==============
|
|
||||||
|
|
||||||
def variable_on_cpu(name, shape, initializer):
    r"""
    Utility used throughout graph creation: obtain the variable ``name`` via
    ``tfv1.get_variable`` while the ``Config.cpu_device`` device scope is active.

    ``shape`` and ``initializer`` are forwarded unchanged; the resulting
    variable is returned.
    """
    # Everything created inside this scope is assigned to the configured CPU device
    with tf.device(Config.cpu_device):
        return tfv1.get_variable(name=name, shape=shape, initializer=initializer)
|
|
||||||
|
|
||||||
|
|
||||||
def create_overlapping_windows(batch_x):
    r"""
    Turn a batch of feature frames into overlapping context windows.

    Every window covers ``2 * Config.n_context + 1`` frames of
    ``Config.n_input`` features each. The result is reshaped to
    ``[batch_size, n_windows, window_width, n_input]``.
    """
    n_batch = tf.shape(input=batch_x)[0]
    frames_per_window = 2 * Config.n_context + 1
    features_per_frame = Config.n_input
    flat_window_size = frames_per_window * features_per_frame

    # An identity matrix reshaped into a convolution kernel makes the
    # convolution copy patches of the input verbatim, which yields the
    # overlapping windows over the MFCCs without any learned weights.
    identity_kernel = np.eye(flat_window_size).reshape(
        frames_per_window, features_per_frame, flat_window_size)
    eye_filter = tf.constant(identity_kernel, tf.float32)

    # Slide the identity kernel over the time axis to collect the windows
    windows = tf.nn.conv1d(input=batch_x, filters=eye_filter, stride=1, padding='SAME')

    # Split the flat window back into [window_width, n_input]
    return tf.reshape(windows, [n_batch, -1, frames_per_window, features_per_frame])
|
|
||||||
|
|
||||||
|
|
||||||
def dense(name, x, units, dropout_rate=None, relu=True):
    r"""
    Fully connected layer ``x @ weights + bias`` with ``units`` outputs.

    The ``bias`` and ``weights`` variables are created through
    ``variable_on_cpu`` inside the variable scope ``name``. When ``relu`` is
    true the activation is a ReLU clipped at ``FLAGS.relu_clip``; when
    ``dropout_rate`` is not ``None`` dropout with that rate is applied last.
    """
    with tfv1.variable_scope(name):
        bias = variable_on_cpu('bias', [units], tfv1.zeros_initializer())
        weights_initializer = tfv1.keras.initializers.VarianceScaling(
            scale=1.0, mode="fan_avg", distribution="uniform")
        weights = variable_on_cpu('weights', [x.shape[-1], units], weights_initializer)

    activations = tf.nn.bias_add(tf.matmul(x, weights), bias)

    if relu:
        # Clipped ReLU: min(max(0, a), relu_clip)
        activations = tf.minimum(tf.nn.relu(activations), FLAGS.relu_clip)

    if dropout_rate is None:
        return activations

    return tf.nn.dropout(activations, rate=dropout_rate)
|
|
||||||
|
|
||||||
|
|
||||||
def rnn_impl_lstmblockfusedcell(x, seq_length, previous_state, reuse):
    r"""
    RNN implementation built on ``tf.contrib.rnn.LSTMBlockFusedCell``.

    The cell has ``Config.n_cell_dim`` units and is created and invoked inside
    the variable scope ``cudnn_lstm/rnn/multi_rnn_cell/cell_0``. Returns the
    ``(output, output_state)`` pair produced by the cell.
    """
    with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell/cell_0'):
        lstm_cell = tf.contrib.rnn.LSTMBlockFusedCell(
            Config.n_cell_dim,
            forget_bias=0,
            reuse=reuse,
            name='cudnn_compatible_lstm_cell',
        )

        rnn_output, rnn_state = lstm_cell(
            inputs=x,
            dtype=tf.float32,
            sequence_length=seq_length,
            initial_state=previous_state,
        )

    return rnn_output, rnn_state
|
|
||||||
|
|
||||||
|
|
||||||
def rnn_impl_cudnn_rnn(x, seq_length, previous_state, _):
    r"""
    RNN implementation built on ``tf.contrib.cudnn_rnn.CudnnLSTM``.

    ``previous_state`` must be ``None``; an ``AssertionError`` is raised
    otherwise. The fourth positional argument (``reuse`` in the sibling
    implementations) is ignored. Returns the ``(output, output_state)`` pair
    produced by the CudnnLSTM layer.
    """
    assert previous_state is None, 'Passing previous state not supported with CuDNN backend'

    # Hack: CudnnLSTM works similarly to Keras layers in that when you instantiate
    # the object it creates the variables, and then you just call it several times
    # to enable variable re-use. Because all of our code is structured in an old
    # school TensorFlow structure where you can just call tf.get_variable again with
    # reuse=True to reuse variables, we can't easily make use of the object oriented
    # way CudnnLSTM is implemented, so we save a singleton instance in the function,
    # emulating a static function variable.
    if rnn_impl_cudnn_rnn.cell is None:
        # Forward direction cell:
        rnn_impl_cudnn_rnn.cell = tf.contrib.cudnn_rnn.CudnnLSTM(num_layers=1,
                                                                 num_units=Config.n_cell_dim,
                                                                 input_mode='linear_input',
                                                                 direction='unidirectional',
                                                                 dtype=tf.float32)

    output, output_state = rnn_impl_cudnn_rnn.cell(inputs=x,
                                                   sequence_lengths=seq_length)

    return output, output_state


# Static function attribute holding the lazily created CudnnLSTM singleton
rnn_impl_cudnn_rnn.cell = None
|
|
||||||
|
|
||||||
|
|
||||||
def rnn_impl_static_rnn(x, seq_length, previous_state, reuse):
    r"""
    RNN implementation built on ``tfv1.nn.rnn_cell.LSTMCell`` unrolled with
    ``tfv1.nn.static_rnn``.

    ``x`` is split along its first axis (whose size must be statically known),
    fed step by step through the cell, and the per-step outputs are
    concatenated along axis 0. Returns ``(output, output_state)``.
    """
    with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell'):
        # Forward direction cell
        lstm_cell = tfv1.nn.rnn_cell.LSTMCell(
            Config.n_cell_dim,
            forget_bias=0,
            reuse=reuse,
            name='cudnn_compatible_lstm_cell',
        )

        # static_rnn wants a Python list of rank N-1 tensors, one per step
        steps = [x[step] for step in range(x.shape[0])]

        step_outputs, final_state = tfv1.nn.static_rnn(
            cell=lstm_cell,
            inputs=steps,
            sequence_length=seq_length,
            initial_state=previous_state,
            dtype=tf.float32,
            scope='cell_0',
        )

        # Stitch the per-step outputs back together into one tensor
        joined_output = tf.concat(step_outputs, 0)

    return joined_output, final_state
|
|
||||||
|
|
||||||
|
|
||||||
def create_model(batch_x, seq_length, dropout, reuse=False, batch_size=None, previous_state=None, overlap=True, rnn_impl=rnn_impl_lstmblockfusedcell):
    r"""
    Assemble the acoustic model graph: three dense layers, the RNN selected by
    ``rnn_impl``, one more dense layer and a final linear layer.

    ``dropout`` is indexed at positions 0, 1, 2 (first three dense layers) and
    5 (the dense layer after the RNN). When ``batch_size`` is falsy it is read
    dynamically from ``batch_x``. When ``overlap`` is true the input is first
    expanded with ``create_overlapping_windows``.

    Returns ``(logits, layers)`` where ``logits`` is the time-major tensor
    named ``raw_logits`` of shape ``[n_steps, batch_size, n_hidden_6]`` and
    ``layers`` is a dict of the intermediate tensors.
    """
    layers = {}

    # Input shape: [batch_size, n_steps, n_input + 2*n_input*n_context]
    if not batch_size:
        batch_size = tf.shape(input=batch_x)[0]

    # Create overlapping feature windows if needed
    if overlap:
        batch_x = create_overlapping_windows(batch_x)

    # The first dense layer expects a rank 2 tensor, so swap the batch and
    # time axes (making the data time-major) and then flatten every window
    # into a single row: (n_steps*batch_size, n_input + 2*n_input*n_context)
    batch_x = tf.transpose(a=batch_x, perm=[1, 0, 2, 3])
    batch_x = tf.reshape(batch_x, [-1, Config.n_input + 2*Config.n_input*Config.n_context])
    layers['input_reshaped'] = batch_x

    # Three hidden layers with clipped RELU activation and dropout
    hidden = batch_x
    hidden_units = (Config.n_hidden_1, Config.n_hidden_2, Config.n_hidden_3)
    for index, units in enumerate(hidden_units):
        layer_name = 'layer_%d' % (index + 1)
        hidden = dense(layer_name, hidden, units, dropout_rate=dropout[index])
        layers[layer_name] = hidden

    # The RNN expects `[max_time, batch_size, input_size]`, so the third
    # layer's output is reshaped into `[n_steps, batch_size, n_hidden_3]`.
    rnn_input = tf.reshape(hidden, [-1, batch_size, Config.n_hidden_3])

    # Run through parametrized RNN implementation, as we use different RNNs
    # for training and inference
    rnn_output, rnn_state = rnn_impl(rnn_input, seq_length, previous_state, reuse)

    # Flatten [n_steps, batch_size, n_cell_dim] to [n_steps*batch_size, n_cell_dim]
    rnn_output = tf.reshape(rnn_output, [-1, Config.n_cell_dim])
    layers['rnn_output'] = rnn_output
    layers['rnn_output_state'] = rnn_state

    # Fifth hidden layer with clipped RELU activation
    layer_5 = dense('layer_5', rnn_output, Config.n_hidden_5, dropout_rate=dropout[5])
    layers['layer_5'] = layer_5

    # Final linear layer (no activation) producing the logits
    layer_6 = dense('layer_6', layer_5, Config.n_hidden_6, relu=False)
    layers['layer_6'] = layer_6

    # Restore the time-major shape [n_steps, batch_size, n_hidden_6].
    # Note that this differs from the input, which is batch-major.
    raw_logits = tf.reshape(layer_6, [-1, batch_size, Config.n_hidden_6], name='raw_logits')
    layers['raw_logits'] = raw_logits

    # Output shape: [n_steps, batch_size, n_hidden_6]
    return raw_logits, layers
|
|
||||||
|
|
||||||
|
|
||||||
# Accuracy and Loss
|
|
||||||
# =================
|
|
||||||
|
|
||||||
# In accord with 'Deep Speech: Scaling up end-to-end speech recognition'
|
|
||||||
# (http://arxiv.org/abs/1412.5567),
|
|
||||||
# the loss function used by our network should be the CTC loss function
|
|
||||||
# (http://www.cs.toronto.edu/~graves/preprint.pdf).
|
|
||||||
# Conveniently, this loss function is implemented in TensorFlow.
|
|
||||||
# Thus, we can simply make use of this implementation to define our loss.
|
|
||||||
|
|
||||||
def calculate_mean_edit_distance_and_loss(iterator, dropout, reuse):
    r'''
    Pull the next mini-batch from ``iterator``, run it through the model and
    compute its CTC loss.

    The RNN implementation is the CuDNN one when ``FLAGS.train_cudnn`` is set
    and the LSTMBlockFusedCell one otherwise.

    Returns ``(avg_loss, non_finite_files)``: the mean loss over the batch and
    the file names gathered at the positions whose loss is not finite.
    '''
    # Obtain the next batch of data
    batch_filenames, (batch_x, batch_seq_len), batch_y = iterator.get_next()

    rnn_impl = rnn_impl_cudnn_rnn if FLAGS.train_cudnn else rnn_impl_lstmblockfusedcell

    # Calculate the logits of the batch
    logits, _ = create_model(batch_x, batch_seq_len, dropout, reuse=reuse, rnn_impl=rnn_impl)

    # Per-sample CTC loss, as implemented by TensorFlow's `ctc_loss`
    per_sample_loss = tfv1.nn.ctc_loss(labels=batch_y, inputs=logits, sequence_length=batch_seq_len)

    # Pick out the files whose loss came out as infinite or NaN
    non_finite_positions = tfv1.where(~tf.math.is_finite(per_sample_loss))
    non_finite_files = tf.gather(batch_filenames, non_finite_positions)

    # Average loss across the batch
    avg_loss = tf.reduce_mean(input_tensor=per_sample_loss)

    return avg_loss, non_finite_files
|
|
||||||
|
|
||||||
|
|
||||||
# Adam Optimization
|
|
||||||
# =================
|
|
||||||
|
|
||||||
# In contrast to 'Deep Speech: Scaling up end-to-end speech recognition'
|
|
||||||
# (http://arxiv.org/abs/1412.5567),
|
|
||||||
# in which 'Nesterov's Accelerated Gradient Descent'
|
|
||||||
# (www.cs.toronto.edu/~fritz/absps/momentum.pdf) was used,
|
|
||||||
# we will use the Adam method for optimization (http://arxiv.org/abs/1412.6980),
|
|
||||||
# because, generally, it requires less fine-tuning.
|
|
||||||
def create_optimizer(learning_rate_var):
    r'''
    Build the Adam optimizer used for training.

    The learning rate comes from ``learning_rate_var``; ``beta1``, ``beta2``
    and ``epsilon`` are taken from the flags of the same names.
    '''
    return tfv1.train.AdamOptimizer(
        learning_rate=learning_rate_var,
        beta1=FLAGS.beta1,
        beta2=FLAGS.beta2,
        epsilon=FLAGS.epsilon,
    )
|
|
||||||
|
|
||||||
|
|
||||||
# Towers
|
|
||||||
# ======
|
|
||||||
|
|
||||||
# In order to properly make use of multiple GPUs, one must introduce new abstractions,
|
|
||||||
# not present when using a single GPU, that facilitate the multi-GPU use case.
|
|
||||||
# In particular, one must introduce a means to isolate the inference and gradient
|
|
||||||
# calculations on the various GPUs.
|
|
||||||
# The abstraction we introduce for this purpose is called a 'tower'.
|
|
||||||
# A tower is specified by two properties:
|
|
||||||
# * **Scope** - A scope, as provided by `tf.name_scope()`,
|
|
||||||
# is a means to isolate the operations within a tower.
|
|
||||||
# For example, all operations within 'tower 0' could have their name prefixed with `tower_0/`.
|
|
||||||
# * **Device** - A hardware device, as provided by `tf.device()`,
|
|
||||||
# on which all operations within the tower execute.
|
|
||||||
# For example, all operations of 'tower 0' could execute on the first GPU `tf.device('/gpu:0')`.
|
|
||||||
|
|
||||||
def get_tower_results(iterator, optimizer, dropout_rates):
    r'''
    Build one tower per entry of ``Config.available_devices``.

    Each tower computes the average loss of its own mini-batch and the
    gradients of that loss via ``optimizer.compute_gradients``. Variables are
    created by the first tower and re-used by all following ones.

    Returns ``(tower_gradients, avg_loss_across_towers, all_non_finite_files)``:
    the per-tower gradient lists, the mean of the tower losses (also recorded
    as the ``step_loss`` summary in the ``step_summaries`` collection) and the
    concatenation of every tower's non-finite-loss file names.
    '''
    # Per-tower accumulators: average losses, gradients, non finite files
    tower_avg_losses = []
    tower_gradients = []
    tower_non_finite_files = []

    with tfv1.variable_scope(tfv1.get_variable_scope()):
        for tower_index, device in enumerate(Config.available_devices):
            # Tower i runs on device i, with all of its operations grouped
            # under the name scope `tower_<i>`
            with tf.device(device), tf.name_scope('tower_%d' % tower_index):
                avg_loss, non_finite_files = calculate_mean_edit_distance_and_loss(
                    iterator, dropout_rates, reuse=tower_index > 0)

                # Allow for variables to be re-used by the next tower
                tfv1.get_variable_scope().reuse_variables()

                tower_avg_losses.append(avg_loss)

                # Gradients for the model parameters from this tower's mini-batch
                tower_gradients.append(optimizer.compute_gradients(avg_loss))

                tower_non_finite_files.append(non_finite_files)

    avg_loss_across_towers = tf.reduce_mean(input_tensor=tower_avg_losses, axis=0)
    tfv1.summary.scalar(name='step_loss', tensor=avg_loss_across_towers, collections=['step_summaries'])

    all_non_finite_files = tf.concat(tower_non_finite_files, axis=0)

    return tower_gradients, avg_loss_across_towers, all_non_finite_files
|
|
||||||
|
|
||||||
|
|
||||||
def average_gradients(tower_gradients):
    r'''
    Average, variable by variable, the gradients computed by the towers.

    ``tower_gradients`` holds one list of ``(gradient, variable)`` pairs per
    tower. The result is a single list of ``(mean_gradient, variable)`` pairs,
    where the variable is taken from the first tower. The averaging ops are
    placed on ``Config.cpu_device`` to conserve GPU memory, and they depend on
    the gradients of every tower.
    '''
    average_grads = []

    with tf.device(Config.cpu_device):
        # zip(*...) regroups the data so that each item holds the
        # gradient/variable pairs of one variable across all towers
        for per_tower_pairs in zip(*tower_gradients):
            # Give every gradient a leading 'tower' axis ...
            stacked = [tf.expand_dims(gradient, 0) for gradient, _ in per_tower_pairs]

            # ... join along it and average it away again
            mean_gradient = tf.reduce_mean(input_tensor=tf.concat(stacked, 0), axis=0)

            shared_variable = per_tower_pairs[0][1]
            average_grads.append((mean_gradient, shared_variable))

    return average_grads
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Logging
|
|
||||||
# =======
|
|
||||||
|
|
||||||
def log_variable(variable, gradient=None):
    r'''
    Record summaries describing the current state of ``variable``.

    Scalar summaries are written for its mean, standard deviation, maximum and
    minimum, plus a histogram of its values. If ``gradient`` is given, a
    histogram of the gradient values is written as well; for a
    ``tf.IndexedSlices`` gradient its ``values`` are used.
    '''
    # ':' is replaced so the variable name can serve as summary tag
    tag = variable.name.replace(':', '_')

    mean = tf.reduce_mean(input_tensor=variable)
    tfv1.summary.scalar(name='%s/mean' % tag, tensor=mean)

    stddev = tf.sqrt(tf.reduce_mean(input_tensor=tf.square(variable - mean)))
    tfv1.summary.scalar(name='%s/sttdev' % tag, tensor=stddev)

    tfv1.summary.scalar(name='%s/max' % tag, tensor=tf.reduce_max(input_tensor=variable))
    tfv1.summary.scalar(name='%s/min' % tag, tensor=tf.reduce_min(input_tensor=variable))
    tfv1.summary.histogram(name=tag, values=variable)

    if gradient is None:
        return

    grad_values = gradient.values if isinstance(gradient, tf.IndexedSlices) else gradient
    if grad_values is not None:
        tfv1.summary.histogram(name='%s/gradients' % tag, values=grad_values)
|
|
||||||
|
|
||||||
|
|
||||||
def log_grads_and_vars(grads_and_vars):
    r'''
    Call ``log_variable`` for every ``(gradient, variable)`` pair in
    ``grads_and_vars``.
    '''
    for grad, var in grads_and_vars:
        log_variable(var, gradient=grad)
|
|
||||||
|
|
||||||
|
|
||||||
def train():
    r'''
    Build the training graph and run the training loop.

    The training set is created from ``FLAGS.train_files`` and, if
    ``FLAGS.dev_files`` is set, one validation set per listed source. All sets
    share a single re-initializable iterator. The flags are written to
    ``flags.txt`` inside ``FLAGS.save_checkpoint_dir``.

    For each of ``FLAGS.epochs`` epochs the training set is run once and a
    checkpoint is saved. With validation sources configured, the
    step-weighted mean validation loss then drives best-model saving, early
    stopping and learning rate reduction on a plateau. A ``KeyboardInterrupt``
    ends the loop without propagating.
    '''
    do_cache_dataset = True

    # Feature caching is switched off as soon as any augmentation is active
    # pylint: disable=too-many-boolean-expressions
    if (FLAGS.data_aug_features_multiplicative > 0 or
            FLAGS.data_aug_features_additive > 0 or
            FLAGS.augmentation_spec_dropout_keeprate < 1 or
            FLAGS.augmentation_freq_and_time_masking or
            FLAGS.augmentation_pitch_and_tempo_scaling or
            FLAGS.augmentation_speed_up_std > 0 or
            FLAGS.augmentation_sparse_warp):
        do_cache_dataset = False

    # Shared with the datasets; polled via raise_if_set() in the batch loop
    exception_box = ExceptionBox()

    # Create training and validation datasets
    train_set = create_dataset(FLAGS.train_files.split(','),
                               batch_size=FLAGS.train_batch_size,
                               enable_cache=FLAGS.feature_cache and do_cache_dataset,
                               cache_path=FLAGS.feature_cache,
                               train_phase=True,
                               exception_box=exception_box,
                               process_ahead=len(Config.available_devices) * FLAGS.train_batch_size * 2,
                               buffering=FLAGS.read_buffer)

    # One structure-only iterator that can be pointed at any of the datasets
    iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
                                                 tfv1.data.get_output_shapes(train_set),
                                                 output_classes=tfv1.data.get_output_classes(train_set))

    # Make initialization ops for switching between the two sets
    train_init_op = iterator.make_initializer(train_set)

    if FLAGS.dev_files:
        dev_sources = FLAGS.dev_files.split(',')
        dev_sets = [create_dataset([source],
                                   batch_size=FLAGS.dev_batch_size,
                                   train_phase=False,
                                   exception_box=exception_box,
                                   process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2,
                                   buffering=FLAGS.read_buffer) for source in dev_sources]
        dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]

    # Dropout
    # The rates are placeholders so that training and validation can feed
    # different values into the same graph.
    dropout_rates = [tfv1.placeholder(tf.float32, name='dropout_{}'.format(i)) for i in range(6)]
    dropout_feed_dict = {
        dropout_rates[0]: FLAGS.dropout_rate,
        dropout_rates[1]: FLAGS.dropout_rate2,
        dropout_rates[2]: FLAGS.dropout_rate3,
        dropout_rates[3]: FLAGS.dropout_rate4,
        dropout_rates[4]: FLAGS.dropout_rate5,
        dropout_rates[5]: FLAGS.dropout_rate6,
    }
    no_dropout_feed_dict = {
        rate: 0. for rate in dropout_rates
    }

    # Building the graph
    learning_rate_var = tfv1.get_variable('learning_rate', initializer=FLAGS.learning_rate, trainable=False)
    # Op that multiplies the learning rate by FLAGS.plateau_reduction in place
    reduce_learning_rate_op = learning_rate_var.assign(tf.multiply(learning_rate_var, FLAGS.plateau_reduction))
    optimizer = create_optimizer(learning_rate_var)

    # Enable mixed precision training
    if FLAGS.automatic_mixed_precision:
        log_info('Enabling automatic mixed precision training.')
        optimizer = tfv1.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)

    gradients, loss, non_finite_files = get_tower_results(iterator, optimizer, dropout_rates)

    # Average tower gradients across GPUs
    avg_tower_gradients = average_gradients(gradients)
    log_grads_and_vars(avg_tower_gradients)

    # global_step is automagically incremented by the optimizer
    global_step = tfv1.train.get_or_create_global_step()
    apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients, global_step=global_step)

    # Summaries
    step_summaries_op = tfv1.summary.merge_all('step_summaries')
    step_summary_writers = {
        'train': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'train'), max_queue=120),
        'dev': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'dev'), max_queue=120)
    }

    # Checkpointing
    checkpoint_saver = tfv1.train.Saver(max_to_keep=FLAGS.max_to_keep)
    checkpoint_path = os.path.join(FLAGS.save_checkpoint_dir, 'train')

    # Separate saver that only ever keeps the single best validation model
    best_dev_saver = tfv1.train.Saver(max_to_keep=1)
    best_dev_path = os.path.join(FLAGS.save_checkpoint_dir, 'best_dev')

    # Save flags next to checkpoints
    os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True)
    flags_file = os.path.join(FLAGS.save_checkpoint_dir, 'flags.txt')
    with open(flags_file, 'w') as fout:
        fout.write(FLAGS.flags_into_string())

    with tfv1.Session(config=Config.session_config) as session:
        log_debug('Session opened.')

        # Prevent further graph changes
        tfv1.get_default_graph().finalize()

        # Load checkpoint or initialize variables
        if FLAGS.load == 'auto':
            method_order = ['best', 'last', 'init']
        else:
            method_order = [FLAGS.load]
        load_or_init_graph(session, method_order)

        def run_set(set_name, epoch, init_op, dataset=None):
            # Runs one full pass over the dataset selected by `init_op`.
            # For set_name == 'train' the gradients are applied and the
            # dropout rates from the flags are fed; otherwise nothing is
            # trained and all dropout rates are fed as 0.
            # Returns (mean loss over the steps, number of steps).
            is_train = set_name == 'train'
            train_op = apply_gradient_op if is_train else []
            feed_dict = dropout_feed_dict if is_train else no_dropout_feed_dict

            total_loss = 0.0
            step_count = 0

            step_summary_writer = step_summary_writers.get(set_name)
            checkpoint_time = time.time()

            # Setup progress bar
            class LossWidget(progressbar.widgets.FormatLabel):
                # Progress bar label showing the running mean loss; it reads
                # total_loss and step_count from the enclosing run_set() call.
                def __init__(self):
                    progressbar.widgets.FormatLabel.__init__(self, format='Loss: %(mean_loss)f')

                def __call__(self, progress, data, **kwargs):
                    data['mean_loss'] = total_loss / step_count if step_count else 0.0
                    return progressbar.widgets.FormatLabel.__call__(self, progress, data, **kwargs)

            prefix = 'Epoch {} | {:>10}'.format(epoch, 'Training' if is_train else 'Validation')
            widgets = [' | ', progressbar.widgets.Timer(),
                       ' | Steps: ', progressbar.widgets.Counter(),
                       ' | ', LossWidget()]
            suffix = ' | Dataset: {}'.format(dataset) if dataset else None
            pbar = create_progressbar(prefix=prefix, widgets=widgets, suffix=suffix).start()

            # Initialize iterator to the appropriate dataset
            session.run(init_op)

            # Batch loop
            while True:
                try:
                    _, current_step, batch_loss, problem_files, step_summary = \
                        session.run([train_op, global_step, loss, non_finite_files, step_summaries_op],
                                    feed_dict=feed_dict)
                    exception_box.raise_if_set()
                except tf.errors.InvalidArgumentError as err:
                    # With sparse warp augmentation enabled this error is
                    # logged and the batch is skipped; otherwise it propagates
                    if FLAGS.augmentation_sparse_warp:
                        log_info("Ignoring sparse warp error: {}".format(err))
                        continue
                    else:
                        raise
                except tf.errors.OutOfRangeError:
                    # The iterator is exhausted: this pass is complete
                    exception_box.raise_if_set()
                    break

                if problem_files.size > 0:
                    problem_files = [f.decode('utf8') for f in problem_files[..., 0]]
                    log_error('The following files caused an infinite (or NaN) '
                              'loss: {}'.format(','.join(problem_files)))

                total_loss += batch_loss
                step_count += 1

                pbar.update(step_count)

                step_summary_writer.add_summary(step_summary, current_step)

                # Time based intermediate checkpoints, training passes only
                if is_train and FLAGS.checkpoint_secs > 0 and time.time() - checkpoint_time > FLAGS.checkpoint_secs:
                    checkpoint_saver.save(session, checkpoint_path, global_step=current_step)
                    checkpoint_time = time.time()

            pbar.finish()
            mean_loss = total_loss / step_count if step_count > 0 else 0.0
            return mean_loss, step_count

        log_info('STARTING Optimization')
        train_start_time = datetime.utcnow()
        best_dev_loss = float('inf')
        dev_losses = []
        epochs_without_improvement = 0
        try:
            for epoch in range(FLAGS.epochs):
                # Training
                log_progress('Training epoch %d...' % epoch)
                train_loss, _ = run_set('train', epoch, train_init_op)
                log_progress('Finished training epoch %d - loss: %f' % (epoch, train_loss))
                checkpoint_saver.save(session, checkpoint_path, global_step=global_step)

                if FLAGS.dev_files:
                    # Validation
                    dev_loss = 0.0
                    total_steps = 0
                    for source, init_op in zip(dev_sources, dev_init_ops):
                        log_progress('Validating epoch %d on %s...' % (epoch, source))
                        set_loss, steps = run_set('dev', epoch, init_op, dataset=source)
                        # Weight each set's mean loss by its number of steps
                        dev_loss += set_loss * steps
                        total_steps += steps
                        log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, source, set_loss))

                    # NOTE(review): run_set() can return 0 steps, so if every
                    # validation set yields no batches total_steps is 0 and
                    # this division raises ZeroDivisionError — confirm whether
                    # that case needs handling.
                    dev_loss = dev_loss / total_steps
                    dev_losses.append(dev_loss)

                    # Count epochs without an improvement for early stopping and reduction of learning rate on a plateau
                    # the improvement has to be greater than FLAGS.es_min_delta
                    if dev_loss > best_dev_loss - FLAGS.es_min_delta:
                        epochs_without_improvement += 1
                    else:
                        epochs_without_improvement = 0

                    # Save new best model
                    if dev_loss < best_dev_loss:
                        best_dev_loss = dev_loss
                        save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename='best_dev_checkpoint')
                        log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path))

                    # Early stopping
                    if FLAGS.early_stop and epochs_without_improvement == FLAGS.es_epochs:
                        log_info('Early stop triggered as the loss did not improve the last {} epochs'.format(
                            epochs_without_improvement))
                        break

                    # Reduce learning rate on plateau
                    if (FLAGS.reduce_lr_on_plateau and
                            epochs_without_improvement % FLAGS.plateau_epochs == 0 and epochs_without_improvement > 0):
                        # If the learning rate was reduced and there is still no improvement
                        # wait FLAGS.plateau_epochs before the learning rate is reduced again
                        session.run(reduce_learning_rate_op)
                        current_learning_rate = learning_rate_var.eval()
                        log_info('Encountered a plateau, reducing learning rate to {}'.format(
                            current_learning_rate))

        except KeyboardInterrupt:
            # Ctrl-C ends training early without an error
            pass
        log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time))
    log_debug('Session closed.')
|
|
||||||
|
|
||||||
|
|
||||||
def test():
    r'''
    Evaluate the model on the comma separated ``FLAGS.test_files``.

    If ``FLAGS.test_output_file`` is set, the samples returned by
    ``evaluate`` are additionally written to that file as JSON.
    '''
    samples = evaluate(FLAGS.test_files.split(','), create_model)
    if FLAGS.test_output_file:
        # Save decoded tuples as JSON, converting NumPy floats to Python floats.
        # The context manager makes sure the file is flushed and closed even
        # if serialization fails part way through.
        with open(FLAGS.test_output_file, 'w') as fout:
            json.dump(samples, fout, default=float)
|
|
||||||
|
|
||||||
|
|
||||||
def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
    r'''
    Build the graph used for inference/export and return its endpoints.

    :param batch_size: Static batch size. A value ``<= 0`` (or ``None``) makes
        the batch dimension dynamic.
    :param n_steps: Static number of time steps of the input placeholder. A
        value ``<= 0`` makes that dimension dynamic.
    :param tflite: When true, the static RNN implementation is used and the
        logits are squeezed along dimension 1.
    :return: A tuple ``(inputs, outputs, layers)`` where ``inputs`` and
        ``outputs`` are dicts of named tensors and ``layers`` is the second
        value returned by ``create_model``.
    :raises NotImplementedError: If the batch size is dynamic and either
        ``tflite`` is set or ``n_steps`` is static (``> 0``).
    '''
    # Decide once whether the batch dimension is dynamic. The original code
    # replaced batch_size by None and later evaluated `batch_size <= 0`, which
    # raises TypeError on Python 3 because None is not orderable.
    dynamic_batch = batch_size is None or batch_size <= 0
    batch_size = None if dynamic_batch else batch_size

    # Create feature computation graph
    input_samples = tfv1.placeholder(tf.float32, [Config.audio_window_samples], 'input_samples')
    samples = tf.expand_dims(input_samples, -1)
    mfccs, _ = samples_to_mfccs(samples, FLAGS.audio_sample_rate)
    mfccs = tf.identity(mfccs, name='mfccs')

    # Input tensor will be of shape [batch_size, n_steps, 2*n_context+1, n_input]
    # This shape is read by the native_client in DS_CreateModel to know the
    # value of n_steps, n_context and n_input. Make sure you update the code
    # there if this shape is changed.
    input_tensor = tfv1.placeholder(tf.float32, [batch_size, n_steps if n_steps > 0 else None, 2 * Config.n_context + 1, Config.n_input], name='input_node')
    seq_length = tfv1.placeholder(tf.int32, [batch_size], name='input_lengths')

    if dynamic_batch:
        # no state management since n_step is expected to be dynamic too (see below)
        previous_state = None
    else:
        previous_state_c = tfv1.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_c')
        previous_state_h = tfv1.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_h')

        previous_state = tf.nn.rnn_cell.LSTMStateTuple(previous_state_c, previous_state_h)

    # One rate per layer
    no_dropout = [None] * 6

    if tflite:
        rnn_impl = rnn_impl_static_rnn
    else:
        rnn_impl = rnn_impl_lstmblockfusedcell

    # NOTE: the sequence length input is dropped based on FLAGS.export_tflite,
    # not on the `tflite` argument; this mirrors the `inputs` dict below.
    logits, layers = create_model(batch_x=input_tensor,
                                  batch_size=batch_size,
                                  seq_length=seq_length if not FLAGS.export_tflite else None,
                                  dropout=no_dropout,
                                  previous_state=previous_state,
                                  overlap=False,
                                  rnn_impl=rnn_impl)

    # TF Lite runtime will check that input dimensions are 1, 2 or 4
    # by default we get 3, the middle one being batch_size which is forced to
    # one on inference graph, so remove that dimension
    if tflite:
        logits = tf.squeeze(logits, [1])

    # Apply softmax for CTC decoder
    logits = tf.nn.softmax(logits, name='logits')

    if dynamic_batch:
        if tflite:
            raise NotImplementedError('dynamic batch_size does not support tflite nor streaming')
        if n_steps > 0:
            raise NotImplementedError('dynamic batch_size expect n_steps to be dynamic too')
        return (
            {
                'input': input_tensor,
                'input_lengths': seq_length,
            },
            {
                'outputs': logits,
            },
            layers
        )

    new_state_c, new_state_h = layers['rnn_output_state']
    new_state_c = tf.identity(new_state_c, name='new_state_c')
    new_state_h = tf.identity(new_state_h, name='new_state_h')

    inputs = {
        'input': input_tensor,
        'previous_state_c': previous_state_c,
        'previous_state_h': previous_state_h,
        'input_samples': input_samples,
    }

    if not FLAGS.export_tflite:
        inputs['input_lengths'] = seq_length

    outputs = {
        'outputs': logits,
        'new_state_c': new_state_c,
        'new_state_h': new_state_h,
        'mfccs': mfccs,
    }

    return inputs, outputs, layers
|
|
||||||
|
|
||||||
|
|
||||||
def file_relative_read(fname):
    r'''
    Return the text content of ``fname``, resolved relative to the directory
    that contains this source file.

    :param fname: File name (or relative path) next to this module.
    :return: The whole file content as a string.
    '''
    # Use a context manager so the handle is closed deterministically instead
    # of being left for the garbage collector.
    with open(os.path.join(os.path.dirname(__file__), fname)) as fin:
        return fin.read()
|
|
||||||
|
|
||||||
|
|
||||||
def export():
    r'''
    Restores the trained variables into a simpler graph that will be exported for serving.

    The inference graph is built from the export flags, metadata constants are
    attached as extra outputs, the variables are restored from a checkpoint and
    frozen, and the result is written to ``FLAGS.export_dir`` either as a
    ``.pb`` file or, when ``FLAGS.export_tflite`` is set, as a ``.tflite``
    file. A Markdown metadata file describing the model is written next to it.
    '''
    log_info('Exporting the model...')
    from tensorflow.python.framework.ops import Tensor, Operation

    inputs, outputs, _ = create_inference_graph(batch_size=FLAGS.export_batch_size, n_steps=FLAGS.n_steps, tflite=FLAGS.export_tflite)

    graph_version = int(file_relative_read('GRAPH_VERSION').strip())
    assert graph_version > 0

    outputs['metadata_version'] = tf.constant([graph_version], name='metadata_version')
    outputs['metadata_sample_rate'] = tf.constant([FLAGS.audio_sample_rate], name='metadata_sample_rate')
    outputs['metadata_feature_win_len'] = tf.constant([FLAGS.feature_win_len], name='metadata_feature_win_len')
    outputs['metadata_feature_win_step'] = tf.constant([FLAGS.feature_win_step], name='metadata_feature_win_step')
    outputs['metadata_beam_width'] = tf.constant([FLAGS.export_beam_width], name='metadata_beam_width')
    outputs['metadata_alphabet'] = tf.constant([Config.alphabet.serialize()], name='metadata_alphabet')

    if FLAGS.export_language:
        outputs['metadata_language'] = tf.constant([FLAGS.export_language.encode('utf-8')], name='metadata_language')

    # Prevent further graph changes
    tfv1.get_default_graph().finalize()

    output_names_tensors = [tensor.op.name for tensor in outputs.values() if isinstance(tensor, Tensor)]
    output_names_ops = [op.name for op in outputs.values() if isinstance(op, Operation)]
    output_names = output_names_tensors + output_names_ops

    # Use the compat.v1 session, like the other session created in this file.
    with tfv1.Session() as session:
        # Restore variables from checkpoint
        if FLAGS.load == 'auto':
            method_order = ['best', 'last']
        else:
            method_order = [FLAGS.load]
        load_or_init_graph(session, method_order)

        output_filename = FLAGS.export_file_name + '.pb'
        if FLAGS.remove_export:
            if os.path.isdir(FLAGS.export_dir):
                log_info('Removing old export')
                shutil.rmtree(FLAGS.export_dir)

        output_graph_path = os.path.join(FLAGS.export_dir, output_filename)

        if not os.path.isdir(FLAGS.export_dir):
            os.makedirs(FLAGS.export_dir)

        frozen_graph = tfv1.graph_util.convert_variables_to_constants(
            sess=session,
            input_graph_def=tfv1.get_default_graph().as_graph_def(),
            output_node_names=output_names)

        frozen_graph = tfv1.graph_util.extract_sub_graph(
            graph_def=frozen_graph,
            dest_nodes=output_names)

        if not FLAGS.export_tflite:
            with open(output_graph_path, 'wb') as fout:
                fout.write(frozen_graph.SerializeToString())
        else:
            # Build the name from the base name directly: str.replace('.pb', ...)
            # would also rewrite any '.pb' occurring inside the user-supplied
            # export file name, not just the extension.
            output_tflite_path = os.path.join(FLAGS.export_dir, FLAGS.export_file_name + '.tflite')

            converter = tf.lite.TFLiteConverter(frozen_graph, input_tensors=inputs.values(), output_tensors=outputs.values())
            converter.optimizations = [tf.lite.Optimize.DEFAULT]
            # AudioSpectrogram and Mfcc ops are custom but have built-in kernels in TFLite
            converter.allow_custom_ops = True
            tflite_model = converter.convert()

            with open(output_tflite_path, 'wb') as fout:
                fout.write(tflite_model)

        log_info('Models exported at %s' % (FLAGS.export_dir))

    metadata_fname = os.path.join(FLAGS.export_dir, '{}_{}_{}.md'.format(
        FLAGS.export_author_id,
        FLAGS.export_model_name,
        FLAGS.export_model_version))

    model_runtime = 'tflite' if FLAGS.export_tflite else 'tensorflow'
    # The metadata file is a front-matter block delimited by '---' lines,
    # followed by the free-form model description.
    with open(metadata_fname, 'w') as f:
        f.write('---\n')
        f.write('author: {}\n'.format(FLAGS.export_author_id))
        f.write('model_name: {}\n'.format(FLAGS.export_model_name))
        f.write('model_version: {}\n'.format(FLAGS.export_model_version))
        f.write('contact_info: {}\n'.format(FLAGS.export_contact_info))
        f.write('license: {}\n'.format(FLAGS.export_license))
        f.write('language: {}\n'.format(FLAGS.export_language))
        f.write('runtime: {}\n'.format(model_runtime))
        f.write('min_ds_version: {}\n'.format(FLAGS.export_min_ds_version))
        f.write('max_ds_version: {}\n'.format(FLAGS.export_max_ds_version))
        f.write('acoustic_model_url: <replace this with a publicly available URL of the acoustic model>\n')
        f.write('scorer_url: <replace this with a publicly available URL of the scorer, if present>\n')
        f.write('---\n')
        f.write('{}\n'.format(FLAGS.export_description))

    log_info('Model metadata file saved to {}. Before submitting the exported model for publishing make sure all information in the metadata file is correct, and complete the URL fields.'.format(metadata_fname))
|
|
||||||
|
|
||||||
|
|
||||||
def package_zip():
    r'''
    Copy the scorer into the export directory and zip that directory.

    The archive is created next to the export directory and named after it:
    --export_dir path/to/export/LANG_CODE/ => path/to/export/LANG_CODE.zip
    '''
    # Joining with '' forces a trailing separator, so dirname() below yields
    # the export directory itself rather than its parent.
    source_dir = os.path.join(os.path.abspath(FLAGS.export_dir), '')
    archive_base = os.path.dirname(source_dir)

    shutil.copy(FLAGS.scorer_path, source_dir)

    archive_path = shutil.make_archive(archive_base, 'zip', source_dir)
    log_info('Exported packaged model {}'.format(archive_path))
|
|
||||||
|
|
||||||
|
|
||||||
def do_single_file_inference(input_file_path):
    r'''
    Run the inference graph on one audio file and print the best transcript.

    :param input_file_path: Path of the audio file handed to
        ``audiofile_to_features``.
    '''
    with tfv1.Session(config=Config.session_config) as session:
        graph_inputs, graph_outputs, _ = create_inference_graph(batch_size=1, n_steps=-1)

        # Restore variables from training checkpoint
        method_order = ['best', 'last'] if FLAGS.load == 'auto' else [FLAGS.load]
        load_or_init_graph(session, method_order)

        features, features_len = audiofile_to_features(input_file_path)

        # Add batch dimension, then evaluate both tensors to concrete values
        batched_features = tf.expand_dims(features, 0)
        batched_len = tf.expand_dims(features_len, 0)
        feature_windows = create_overlapping_windows(batched_features).eval(session=session)
        lengths = batched_len.eval(session=session)

        # The recurrent state fed to the graph starts out as zeros
        zero_state_c = np.zeros([1, Config.n_cell_dim])
        zero_state_h = np.zeros([1, Config.n_cell_dim])

        feed = {
            graph_inputs['input']: feature_windows,
            graph_inputs['input_lengths']: lengths,
            graph_inputs['previous_state_c']: zero_state_c,
            graph_inputs['previous_state_h']: zero_state_h,
        }
        logits = np.squeeze(graph_outputs['outputs'].eval(feed_dict=feed, session=session))

        scorer = None
        if FLAGS.scorer_path:
            scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
                            FLAGS.scorer_path, Config.alphabet)

        decoded = ctc_beam_search_decoder(logits, Config.alphabet, FLAGS.beam_width,
                                          scorer=scorer, cutoff_prob=FLAGS.cutoff_prob,
                                          cutoff_top_n=FLAGS.cutoff_top_n)
        # Print highest probability result
        best_candidate = decoded[0]
        print(best_candidate[1])
|
|
||||||
|
|
||||||
|
|
||||||
def main(_):
    r'''
    Entry point: run every stage that was requested through the flags.

    Stages run in this order: training, testing, export (plain or zipped) and
    one-shot inference. The default graph is reset before each stage.
    '''
    initialize_globals()

    if FLAGS.train_files:
        tfv1.reset_default_graph()
        tfv1.set_random_seed(FLAGS.random_seed)
        train()

    if FLAGS.test_files:
        tfv1.reset_default_graph()
        test()

    if FLAGS.export_zip:
        # A zipped export is always a TFLite export into an empty directory
        tfv1.reset_default_graph()
        FLAGS.export_tflite = True

        existing_entries = os.listdir(FLAGS.export_dir)
        if existing_entries:
            log_error('Directory {} is not empty, please fix this.'.format(FLAGS.export_dir))
            sys.exit(1)

        export()
        package_zip()
    elif FLAGS.export_dir:
        # Plain export, only when no zip package was requested
        tfv1.reset_default_graph()
        export()

    if FLAGS.one_shot_infer:
        tfv1.reset_default_graph()
        do_single_file_inference(FLAGS.one_shot_infer)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
create_flags()
|
try:
|
||||||
absl.app.run(main)
|
from deepspeech_training import train as ds_train
|
||||||
|
except ImportError:
|
||||||
|
print('Training package is not installed. See training documentation.')
|
||||||
|
raise
|
||||||
|
|
||||||
|
ds_train.run_script()
|
||||||
|
|
|
@ -150,7 +150,7 @@ COPY . /DeepSpeech/
|
||||||
|
|
||||||
WORKDIR /DeepSpeech
|
WORKDIR /DeepSpeech
|
||||||
|
|
||||||
RUN pip3 --no-cache-dir install -r requirements.txt
|
RUN pip3 --no-cache-dir install .
|
||||||
|
|
||||||
# Link DeepSpeech native_client libs to tf folder
|
# Link DeepSpeech native_client libs to tf folder
|
||||||
RUN ln -s /DeepSpeech/native_client /tensorflow
|
RUN ln -s /DeepSpeech/native_client /tensorflow
|
||||||
|
|
|
@ -5,18 +5,12 @@ Use "python3 build_sdb.py -h" for help
|
||||||
'''
|
'''
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
# Make sure we can import stuff from util/
|
|
||||||
# This script needs to be run from the root of the DeepSpeech repository
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import progressbar
|
import progressbar
|
||||||
|
|
||||||
from util.downloader import SIMPLE_BAR
|
from deepspeech_training.util.downloader import SIMPLE_BAR
|
||||||
from util.audio import change_audio_types, AUDIO_TYPE_WAV, AUDIO_TYPE_OPUS
|
from deepspeech_training.util.audio import change_audio_types, AUDIO_TYPE_WAV, AUDIO_TYPE_OPUS
|
||||||
from util.sample_collections import samples_from_files, DirectSDBWriter
|
from deepspeech_training.util.sample_collections import samples_from_files, DirectSDBWriter
|
||||||
|
|
||||||
AUDIO_TYPE_LOOKUP = {
|
AUDIO_TYPE_LOOKUP = {
|
||||||
'wav': AUDIO_TYPE_WAV,
|
'wav': AUDIO_TYPE_WAV,
|
||||||
|
|
|
@ -1,11 +1,5 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
|
from deepspeech_training.util.gpu_usage import GPUUsage
|
||||||
import sys
|
|
||||||
|
|
||||||
import os
|
|
||||||
sys.path.append(os.path.abspath('.'))
|
|
||||||
|
|
||||||
from util.gpu_usage import GPUUsage
|
|
||||||
|
|
||||||
gu = GPUUsage()
|
gu = GPUUsage()
|
||||||
gu.start()
|
gu.start()
|
||||||
|
|
|
@ -1,10 +1,6 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
import os
|
from deepspeech_training.util.gpu_usage import GPUUsageChart
|
||||||
sys.path.append(os.path.abspath('.'))
|
|
||||||
|
|
||||||
from util.gpu_usage import GPUUsageChart
|
|
||||||
|
|
||||||
GPUUsageChart(sys.argv[1], sys.argv[2])
|
GPUUsageChart(sys.argv[1], sys.argv[2])
|
||||||
|
|
|
@ -1,17 +1,12 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
# Make sure we can import stuff from util/
|
|
||||||
# This script needs to be run from the root of the DeepSpeech repository
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
|
||||||
|
|
||||||
from util.importers import get_importers_parser
|
|
||||||
import glob
|
import glob
|
||||||
|
import os
|
||||||
import pandas
|
import pandas
|
||||||
import tarfile
|
import tarfile
|
||||||
|
|
||||||
|
from deepspeech_training.util.importers import get_importers_parser
|
||||||
|
|
||||||
COLUMN_NAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
COLUMN_NAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
||||||
|
|
||||||
|
|
|
@ -1,17 +1,12 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
# Make sure we can import stuff from util/
|
|
||||||
# This script needs to be run from the root of the DeepSpeech repository
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
|
||||||
|
|
||||||
from util.importers import get_importers_parser
|
|
||||||
import glob
|
import glob
|
||||||
import tarfile
|
import os
|
||||||
import pandas
|
import pandas
|
||||||
|
import tarfile
|
||||||
|
|
||||||
|
from deepspeech_training.util.importers import get_importers_parser
|
||||||
|
|
||||||
COLUMNNAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
COLUMNNAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
||||||
|
|
||||||
|
|
|
@ -1,23 +1,17 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
# Make sure we can import stuff from util/
|
|
||||||
# This script needs to be run from the root of the DeepSpeech repository
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
|
||||||
|
|
||||||
import csv
|
import csv
|
||||||
import sox
|
import os
|
||||||
import tarfile
|
|
||||||
import subprocess
|
|
||||||
import progressbar
|
import progressbar
|
||||||
|
import sox
|
||||||
|
import subprocess
|
||||||
|
import tarfile
|
||||||
|
|
||||||
from glob import glob
|
from glob import glob
|
||||||
from os import path
|
|
||||||
from multiprocessing import Pool
|
from multiprocessing import Pool
|
||||||
from util.importers import validate_label_eng as validate_label, get_counter, get_imported_samples, print_import_report
|
from deepspeech_training.util.importers import validate_label_eng as validate_label, get_counter, get_imported_samples, print_import_report
|
||||||
from util.downloader import maybe_download, SIMPLE_BAR
|
from deepspeech_training.util.downloader import maybe_download, SIMPLE_BAR
|
||||||
|
|
||||||
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
||||||
SAMPLE_RATE = 16000
|
SAMPLE_RATE = 16000
|
||||||
|
@ -28,7 +22,7 @@ ARCHIVE_URL = 'https://s3.us-east-2.amazonaws.com/common-voice-data-download/' +
|
||||||
|
|
||||||
def _download_and_preprocess_data(target_dir):
|
def _download_and_preprocess_data(target_dir):
|
||||||
# Making path absolute
|
# Making path absolute
|
||||||
target_dir = path.abspath(target_dir)
|
target_dir = os.path.abspath(target_dir)
|
||||||
# Conditionally download data
|
# Conditionally download data
|
||||||
archive_path = maybe_download(ARCHIVE_NAME, target_dir, ARCHIVE_URL)
|
archive_path = maybe_download(ARCHIVE_NAME, target_dir, ARCHIVE_URL)
|
||||||
# Conditionally extract common voice data
|
# Conditionally extract common voice data
|
||||||
|
@ -38,8 +32,8 @@ def _download_and_preprocess_data(target_dir):
|
||||||
|
|
||||||
def _maybe_extract(target_dir, extracted_data, archive_path):
|
def _maybe_extract(target_dir, extracted_data, archive_path):
|
||||||
# If target_dir/extracted_data does not exist, extract archive in target_dir
|
# If target_dir/extracted_data does not exist, extract archive in target_dir
|
||||||
extracted_path = path.join(target_dir, extracted_data)
|
# `os` has no `join` attribute; path joining lives in `os.path`.
extracted_path = os.path.join(target_dir, extracted_data)
|
||||||
if not path.exists(extracted_path):
|
if not os.path.exists(extracted_path):
|
||||||
print('No directory "%s" - extracting archive...' % extracted_path)
|
print('No directory "%s" - extracting archive...' % extracted_path)
|
||||||
with tarfile.open(archive_path) as tar:
|
with tarfile.open(archive_path) as tar:
|
||||||
tar.extractall(target_dir)
|
tar.extractall(target_dir)
|
||||||
|
@ -47,9 +41,9 @@ def _maybe_extract(target_dir, extracted_data, archive_path):
|
||||||
print('Found directory "%s" - not extracting it from archive.' % extracted_path)
|
print('Found directory "%s" - not extracting it from archive.' % extracted_path)
|
||||||
|
|
||||||
def _maybe_convert_sets(target_dir, extracted_data):
|
def _maybe_convert_sets(target_dir, extracted_data):
|
||||||
extracted_dir = path.join(target_dir, extracted_data)
|
extracted_dir = os.path.join(target_dir, extracted_data)
|
||||||
for source_csv in glob(path.join(extracted_dir, '*.csv')):
|
for source_csv in glob(os.path.join(extracted_dir, '*.csv')):
|
||||||
_maybe_convert_set(extracted_dir, source_csv, path.join(target_dir, os.path.split(source_csv)[-1]))
|
_maybe_convert_set(extracted_dir, source_csv, os.path.join(target_dir, os.path.split(source_csv)[-1]))
|
||||||
|
|
||||||
def one_sample(sample):
|
def one_sample(sample):
|
||||||
mp3_filename = sample[0]
|
mp3_filename = sample[0]
|
||||||
|
@ -58,7 +52,7 @@ def one_sample(sample):
|
||||||
_maybe_convert_wav(mp3_filename, wav_filename)
|
_maybe_convert_wav(mp3_filename, wav_filename)
|
||||||
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
|
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
|
||||||
file_size = -1
|
file_size = -1
|
||||||
if path.exists(wav_filename):
|
if os.path.exists(wav_filename):
|
||||||
file_size = path.getsize(wav_filename)
|
# Use the fully qualified call: the bare `path` name (previously bound by
# `from os import path`) is no longer imported at the top of this script.
file_size = os.path.getsize(wav_filename)
|
||||||
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
|
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
|
||||||
label = validate_label(sample[1])
|
label = validate_label(sample[1])
|
||||||
|
@ -85,7 +79,7 @@ def one_sample(sample):
|
||||||
|
|
||||||
def _maybe_convert_set(extracted_dir, source_csv, target_csv):
|
def _maybe_convert_set(extracted_dir, source_csv, target_csv):
|
||||||
print()
|
print()
|
||||||
if path.exists(target_csv):
|
if os.path.exists(target_csv):
|
||||||
print('Found CSV file "%s" - not importing "%s".' % (target_csv, source_csv))
|
print('Found CSV file "%s" - not importing "%s".' % (target_csv, source_csv))
|
||||||
return
|
return
|
||||||
print('No CSV file "%s" - importing "%s"...' % (target_csv, source_csv))
|
print('No CSV file "%s" - importing "%s"...' % (target_csv, source_csv))
|
||||||
|
@ -126,7 +120,7 @@ def _maybe_convert_set(extracted_dir, source_csv, target_csv):
|
||||||
print_import_report(counter, SAMPLE_RATE, MAX_SECS)
|
print_import_report(counter, SAMPLE_RATE, MAX_SECS)
|
||||||
|
|
||||||
def _maybe_convert_wav(mp3_filename, wav_filename):
|
def _maybe_convert_wav(mp3_filename, wav_filename):
|
||||||
if not path.exists(wav_filename):
|
if not os.path.exists(wav_filename):
|
||||||
transformer = sox.Transformer()
|
transformer = sox.Transformer()
|
||||||
transformer.convert(samplerate=SAMPLE_RATE)
|
transformer.convert(samplerate=SAMPLE_RATE)
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -8,23 +8,17 @@ Use "python3 import_cv2.py -h" for help
|
||||||
'''
|
'''
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
# Make sure we can import stuff from util/
|
|
||||||
# This script needs to be run from the root of the DeepSpeech repository
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
|
||||||
|
|
||||||
import csv
|
import csv
|
||||||
|
import os
|
||||||
|
import progressbar
|
||||||
import sox
|
import sox
|
||||||
import subprocess
|
import subprocess
|
||||||
import progressbar
|
|
||||||
import unicodedata
|
import unicodedata
|
||||||
|
|
||||||
from os import path
|
|
||||||
from multiprocessing import Pool
|
from multiprocessing import Pool
|
||||||
from util.downloader import SIMPLE_BAR
|
from deepspeech_training.util.downloader import SIMPLE_BAR
|
||||||
from util.text import Alphabet
|
from deepspeech_training.util.text import Alphabet
|
||||||
from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
|
from deepspeech_training.util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
|
||||||
|
|
||||||
|
|
||||||
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
||||||
|
@ -34,7 +28,7 @@ MAX_SECS = 10
|
||||||
|
|
||||||
def _preprocess_data(tsv_dir, audio_dir, space_after_every_character=False):
|
def _preprocess_data(tsv_dir, audio_dir, space_after_every_character=False):
|
||||||
for dataset in ['train', 'test', 'dev', 'validated', 'other']:
|
for dataset in ['train', 'test', 'dev', 'validated', 'other']:
|
||||||
input_tsv = path.join(path.abspath(tsv_dir), dataset+".tsv")
|
input_tsv = os.path.join(os.path.abspath(tsv_dir), dataset+".tsv")
|
||||||
if os.path.isfile(input_tsv):
|
if os.path.isfile(input_tsv):
|
||||||
print("Loading TSV file: ", input_tsv)
|
print("Loading TSV file: ", input_tsv)
|
||||||
_maybe_convert_set(input_tsv, audio_dir, space_after_every_character)
|
_maybe_convert_set(input_tsv, audio_dir, space_after_every_character)
|
||||||
|
@ -42,15 +36,15 @@ def _preprocess_data(tsv_dir, audio_dir, space_after_every_character=False):
|
||||||
def one_sample(sample):
|
def one_sample(sample):
|
||||||
""" Take a audio file, and optionally convert it to 16kHz WAV """
|
""" Take a audio file, and optionally convert it to 16kHz WAV """
|
||||||
mp3_filename = sample[0]
|
mp3_filename = sample[0]
|
||||||
if not path.splitext(mp3_filename.lower())[1] == '.mp3':
|
if not os.path.splitext(mp3_filename.lower())[1] == '.mp3':
|
||||||
mp3_filename += ".mp3"
|
mp3_filename += ".mp3"
|
||||||
# Storing wav files next to the mp3 ones - just with a different suffix
|
# Storing wav files next to the mp3 ones - just with a different suffix
|
||||||
wav_filename = path.splitext(mp3_filename)[0] + ".wav"
|
wav_filename = os.path.splitext(mp3_filename)[0] + ".wav"
|
||||||
_maybe_convert_wav(mp3_filename, wav_filename)
|
_maybe_convert_wav(mp3_filename, wav_filename)
|
||||||
file_size = -1
|
file_size = -1
|
||||||
frames = 0
|
frames = 0
|
||||||
if path.exists(wav_filename):
|
if os.path.exists(wav_filename):
|
||||||
file_size = path.getsize(wav_filename)
|
file_size = os.path.getsize(wav_filename)
|
||||||
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
|
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
|
||||||
label = label_filter_fun(sample[1])
|
label = label_filter_fun(sample[1])
|
||||||
rows = []
|
rows = []
|
||||||
|
@ -76,7 +70,7 @@ def one_sample(sample):
|
||||||
return (counter, rows)
|
return (counter, rows)
|
||||||
|
|
||||||
def _maybe_convert_set(input_tsv, audio_dir, space_after_every_character=None):
|
def _maybe_convert_set(input_tsv, audio_dir, space_after_every_character=None):
|
||||||
output_csv = path.join(audio_dir, os.path.split(input_tsv)[-1].replace('tsv', 'csv'))
|
output_csv = os.path.join(audio_dir, os.path.split(input_tsv)[-1].replace('tsv', 'csv'))
|
||||||
print("Saving new DeepSpeech-formatted CSV file to: ", output_csv)
|
print("Saving new DeepSpeech-formatted CSV file to: ", output_csv)
|
||||||
|
|
||||||
# Get audiofile path and transcript for each sentence in tsv
|
# Get audiofile path and transcript for each sentence in tsv
|
||||||
|
@ -84,7 +78,7 @@ def _maybe_convert_set(input_tsv, audio_dir, space_after_every_character=None):
|
||||||
with open(input_tsv, encoding='utf-8') as input_tsv_file:
|
with open(input_tsv, encoding='utf-8') as input_tsv_file:
|
||||||
reader = csv.DictReader(input_tsv_file, delimiter='\t')
|
reader = csv.DictReader(input_tsv_file, delimiter='\t')
|
||||||
for row in reader:
|
for row in reader:
|
||||||
samples.append((path.join(audio_dir, row['path']), row['sentence']))
|
samples.append((os.path.join(audio_dir, row['path']), row['sentence']))
|
||||||
|
|
||||||
counter = get_counter()
|
counter = get_counter()
|
||||||
num_samples = len(samples)
|
num_samples = len(samples)
|
||||||
|
@ -120,7 +114,7 @@ def _maybe_convert_set(input_tsv, audio_dir, space_after_every_character=None):
|
||||||
|
|
||||||
|
|
||||||
def _maybe_convert_wav(mp3_filename, wav_filename):
|
def _maybe_convert_wav(mp3_filename, wav_filename):
|
||||||
if not path.exists(wav_filename):
|
if not os.path.exists(wav_filename):
|
||||||
transformer = sox.Transformer()
|
transformer = sox.Transformer()
|
||||||
transformer.convert(samplerate=SAMPLE_RATE)
|
transformer.convert(samplerate=SAMPLE_RATE)
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -4,22 +4,17 @@ from __future__ import absolute_import, division, print_function
|
||||||
# Prerequisite: Having the sph2pipe tool in your PATH:
|
# Prerequisite: Having the sph2pipe tool in your PATH:
|
||||||
# https://www.ldc.upenn.edu/language-resources/tools/sphere-conversion-tools
|
# https://www.ldc.upenn.edu/language-resources/tools/sphere-conversion-tools
|
||||||
|
|
||||||
# Make sure we can import stuff from util/
|
|
||||||
# This script needs to be run from the root of the DeepSpeech repository
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
|
||||||
|
|
||||||
import codecs
|
import codecs
|
||||||
import fnmatch
|
import fnmatch
|
||||||
|
import librosa
|
||||||
import os
|
import os
|
||||||
import pandas
|
import pandas
|
||||||
import subprocess
|
|
||||||
import unicodedata
|
|
||||||
import librosa
|
|
||||||
import soundfile # <= Has an external dependency on libsndfile
|
import soundfile # <= Has an external dependency on libsndfile
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import unicodedata
|
||||||
|
|
||||||
from util.importers import validate_label_eng as validate_label
|
from deepspeech_training.util.importers import validate_label_eng as validate_label
|
||||||
|
|
||||||
def _download_and_preprocess_data(data_dir):
|
def _download_and_preprocess_data(data_dir):
|
||||||
# Assume data_dir contains extracted LDC2004S13, LDC2004T19, LDC2005S13, LDC2005T19
|
# Assume data_dir contains extracted LDC2004S13, LDC2004T19, LDC2005S13, LDC2005T19
|
||||||
|
|
|
@ -1,18 +1,13 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
# Make sure we can import stuff from util/
|
|
||||||
# This script needs to be run from the root of the DeepSpeech repository
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
|
||||||
|
|
||||||
from util.importers import get_importers_parser
|
|
||||||
import glob
|
import glob
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import os
|
||||||
import pandas
|
import pandas
|
||||||
import tarfile
|
import tarfile
|
||||||
|
|
||||||
|
from deepspeech_training.util.importers import get_importers_parser
|
||||||
|
|
||||||
COLUMN_NAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
COLUMN_NAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
||||||
|
|
||||||
|
|
|
@ -1,22 +1,16 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
|
|
||||||
# Make sure we can import stuff from util/
|
|
||||||
# This script needs to be run from the root of the DeepSpeech repository
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
|
||||||
|
|
||||||
import csv
|
import csv
|
||||||
import math
|
|
||||||
import urllib
|
|
||||||
import logging
|
import logging
|
||||||
from util.importers import get_importers_parser, get_validate_label
|
import math
|
||||||
import subprocess
|
import os
|
||||||
from os import path
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import swifter
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
import swifter
|
||||||
|
import subprocess
|
||||||
|
import urllib
|
||||||
|
|
||||||
|
from deepspeech_training.util.importers import get_importers_parser, get_validate_label
|
||||||
|
from pathlib import Path
|
||||||
from sox import Transformer
|
from sox import Transformer
|
||||||
|
|
||||||
|
|
||||||
|
@ -142,11 +136,11 @@ class GramVaaniDownloader:
|
||||||
return mp3_directory
|
return mp3_directory
|
||||||
|
|
||||||
def _pre_download(self):
|
def _pre_download(self):
|
||||||
mp3_directory = path.join(self.target_dir, "mp3")
|
mp3_directory = os.path.join(self.target_dir, "mp3")
|
||||||
if not path.exists(self.target_dir):
|
if not os.path.exists(self.target_dir):
|
||||||
_logger.info("Creating directory...%s", self.target_dir)
|
_logger.info("Creating directory...%s", self.target_dir)
|
||||||
os.mkdir(self.target_dir)
|
os.mkdir(self.target_dir)
|
||||||
if not path.exists(mp3_directory):
|
if not os.path.exists(mp3_directory):
|
||||||
_logger.info("Creating directory...%s", mp3_directory)
|
_logger.info("Creating directory...%s", mp3_directory)
|
||||||
os.mkdir(mp3_directory)
|
os.mkdir(mp3_directory)
|
||||||
return mp3_directory
|
return mp3_directory
|
||||||
|
@ -154,8 +148,8 @@ class GramVaaniDownloader:
|
||||||
def _download(self, audio_url, transcript, audio_length, mp3_directory):
|
def _download(self, audio_url, transcript, audio_length, mp3_directory):
|
||||||
if audio_url == "audio_url":
|
if audio_url == "audio_url":
|
||||||
return
|
return
|
||||||
mp3_filename = path.join(mp3_directory, os.path.basename(audio_url))
|
mp3_filename = os.path.join(mp3_directory, os.path.basename(audio_url))
|
||||||
if not path.exists(mp3_filename):
|
if not os.path.exists(mp3_filename):
|
||||||
_logger.debug("Downloading mp3 file...%s", audio_url)
|
_logger.debug("Downloading mp3 file...%s", audio_url)
|
||||||
urllib.request.urlretrieve(audio_url, mp3_filename)
|
urllib.request.urlretrieve(audio_url, mp3_filename)
|
||||||
else:
|
else:
|
||||||
|
@ -182,8 +176,8 @@ class GramVaaniConverter:
|
||||||
"""
|
"""
|
||||||
wav_directory = self._pre_convert()
|
wav_directory = self._pre_convert()
|
||||||
for mp3_filename in self.mp3_directory.glob('**/*.mp3'):
|
for mp3_filename in self.mp3_directory.glob('**/*.mp3'):
|
||||||
wav_filename = path.join(wav_directory, os.path.splitext(os.path.basename(mp3_filename))[0] + ".wav")
|
wav_filename = os.path.join(wav_directory, os.path.splitext(os.path.basename(mp3_filename))[0] + ".wav")
|
||||||
if not path.exists(wav_filename):
|
if not os.path.exists(wav_filename):
|
||||||
_logger.debug("Converting mp3 file %s to wav file %s" % (mp3_filename, wav_filename))
|
_logger.debug("Converting mp3 file %s to wav file %s" % (mp3_filename, wav_filename))
|
||||||
transformer = Transformer()
|
transformer = Transformer()
|
||||||
transformer.convert(samplerate=SAMPLE_RATE, n_channels=N_CHANNELS, bitdepth=BITDEPTH)
|
transformer.convert(samplerate=SAMPLE_RATE, n_channels=N_CHANNELS, bitdepth=BITDEPTH)
|
||||||
|
@ -193,11 +187,11 @@ class GramVaaniConverter:
|
||||||
return wav_directory
|
return wav_directory
|
||||||
|
|
||||||
def _pre_convert(self):
|
def _pre_convert(self):
|
||||||
wav_directory = path.join(self.target_dir, "wav")
|
wav_directory = os.path.join(self.target_dir, "wav")
|
||||||
if not path.exists(self.target_dir):
|
if not os.path.exists(self.target_dir):
|
||||||
_logger.info("Creating directory...%s", self.target_dir)
|
_logger.info("Creating directory...%s", self.target_dir)
|
||||||
os.mkdir(self.target_dir)
|
os.mkdir(self.target_dir)
|
||||||
if not path.exists(wav_directory):
|
if not os.path.exists(wav_directory):
|
||||||
_logger.info("Creating directory...%s", wav_directory)
|
_logger.info("Creating directory...%s", wav_directory)
|
||||||
os.mkdir(wav_directory)
|
os.mkdir(wav_directory)
|
||||||
return wav_directory
|
return wav_directory
|
||||||
|
@ -233,8 +227,8 @@ class GramVaaniDataSets:
|
||||||
if audio_url == "audio_url":
|
if audio_url == "audio_url":
|
||||||
return pd.Series(["wav_filename", "wav_filesize", "transcript"])
|
return pd.Series(["wav_filename", "wav_filesize", "transcript"])
|
||||||
mp3_filename = os.path.basename(audio_url)
|
mp3_filename = os.path.basename(audio_url)
|
||||||
wav_relative_filename = path.join("wav", os.path.splitext(os.path.basename(mp3_filename))[0] + ".wav")
|
wav_relative_filename = os.path.join("wav", os.path.splitext(os.path.basename(mp3_filename))[0] + ".wav")
|
||||||
wav_filesize = path.getsize(path.join(self.target_dir, wav_relative_filename))
|
wav_filesize = os.path.getsize(os.path.join(self.target_dir, wav_relative_filename))
|
||||||
transcript = validate_label(transcript)
|
transcript = validate_label(transcript)
|
||||||
if None == transcript:
|
if None == transcript:
|
||||||
transcript = ""
|
transcript = ""
|
||||||
|
@ -252,7 +246,7 @@ class GramVaaniDataSets:
|
||||||
|
|
||||||
def _is_valid_raw_wav_frames(self):
|
def _is_valid_raw_wav_frames(self):
|
||||||
transcripts = [str(transcript) for transcript in self.raw.transcript]
|
transcripts = [str(transcript) for transcript in self.raw.transcript]
|
||||||
wav_filepaths = [path.join(self.target_dir, str(wav_filename)) for wav_filename in self.raw.wav_filename]
|
wav_filepaths = [os.path.join(self.target_dir, str(wav_filename)) for wav_filename in self.raw.wav_filename]
|
||||||
wav_frames = [int(subprocess.check_output(['soxi', '-s', wav_filepath], stderr=subprocess.STDOUT)) for wav_filepath in wav_filepaths]
|
wav_frames = [int(subprocess.check_output(['soxi', '-s', wav_filepath], stderr=subprocess.STDOUT)) for wav_filepath in wav_filepaths]
|
||||||
is_valid_raw_wav_frames = [self._is_wav_frame_valid(wav_frame, transcript) for wav_frame, transcript in zip(wav_frames, transcripts)]
|
is_valid_raw_wav_frames = [self._is_wav_frame_valid(wav_frame, transcript) for wav_frame, transcript in zip(wav_frames, transcripts)]
|
||||||
return pd.Series(is_valid_raw_wav_frames)
|
return pd.Series(is_valid_raw_wav_frames)
|
||||||
|
|
|
@ -1,15 +1,11 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
# Make sure we can import stuff from util/
|
|
||||||
# This script needs to be run from the root of the DeepSpeech repository
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
|
||||||
|
|
||||||
import pandas
|
import pandas
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
from util.downloader import maybe_download
|
from deepspeech_training.util.downloader import maybe_download
|
||||||
|
|
||||||
def _download_and_preprocess_data(data_dir):
|
def _download_and_preprocess_data(data_dir):
|
||||||
# Conditionally download data
|
# Conditionally download data
|
||||||
|
|
|
@ -1,22 +1,18 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
# Make sure we can import stuff from util/
|
|
||||||
# This script needs to be run from the root of the DeepSpeech repository
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
|
||||||
|
|
||||||
import codecs
|
import codecs
|
||||||
import fnmatch
|
import fnmatch
|
||||||
|
import os
|
||||||
import pandas
|
import pandas
|
||||||
import progressbar
|
import progressbar
|
||||||
import subprocess
|
import subprocess
|
||||||
|
import sys
|
||||||
import tarfile
|
import tarfile
|
||||||
import unicodedata
|
import unicodedata
|
||||||
|
|
||||||
|
from deepspeech_training.util.downloader import maybe_download
|
||||||
from sox import Transformer
|
from sox import Transformer
|
||||||
from util.downloader import maybe_download
|
|
||||||
from tensorflow.python.platform import gfile
|
from tensorflow.python.platform import gfile
|
||||||
|
|
||||||
SAMPLE_RATE = 16000
|
SAMPLE_RATE = 16000
|
||||||
|
|
|
@ -1,31 +1,23 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
# Make sure we can import stuff from util/
|
|
||||||
# This script needs to be run from the root of the DeepSpeech repository
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
|
||||||
|
|
||||||
from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import csv
|
import csv
|
||||||
|
import os
|
||||||
|
import progressbar
|
||||||
import re
|
import re
|
||||||
import sox
|
import sox
|
||||||
import zipfile
|
|
||||||
import subprocess
|
import subprocess
|
||||||
import progressbar
|
|
||||||
import unicodedata
|
import unicodedata
|
||||||
|
import zipfile
|
||||||
|
|
||||||
from multiprocessing import Pool
|
from deepspeech_training.util.downloader import maybe_download
|
||||||
from util.downloader import SIMPLE_BAR
|
from deepspeech_training.util.downloader import SIMPLE_BAR
|
||||||
|
from deepspeech_training.util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
|
||||||
from os import path
|
from deepspeech_training.util.text import Alphabet
|
||||||
from glob import glob
|
from glob import glob
|
||||||
|
from multiprocessing import Pool
|
||||||
|
|
||||||
from util.downloader import maybe_download
|
|
||||||
from util.text import Alphabet
|
|
||||||
|
|
||||||
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
||||||
SAMPLE_RATE = 16000
|
SAMPLE_RATE = 16000
|
||||||
|
@ -38,7 +30,7 @@ ARCHIVE_URL = 'https://lingualibre.fr/datasets/' + ARCHIVE_NAME
|
||||||
|
|
||||||
def _download_and_preprocess_data(target_dir):
|
def _download_and_preprocess_data(target_dir):
|
||||||
# Making path absolute
|
# Making path absolute
|
||||||
target_dir = path.abspath(target_dir)
|
target_dir = os.path.abspath(target_dir)
|
||||||
# Conditionally download data
|
# Conditionally download data
|
||||||
archive_path = maybe_download(ARCHIVE_NAME, target_dir, ARCHIVE_URL)
|
archive_path = maybe_download(ARCHIVE_NAME, target_dir, ARCHIVE_URL)
|
||||||
# Conditionally extract data
|
# Conditionally extract data
|
||||||
|
@ -48,8 +40,8 @@ def _download_and_preprocess_data(target_dir):
|
||||||
|
|
||||||
def _maybe_extract(target_dir, extracted_data, archive_path):
|
def _maybe_extract(target_dir, extracted_data, archive_path):
|
||||||
# If target_dir/extracted_data does not exist, extract archive in target_dir
|
# If target_dir/extracted_data does not exist, extract archive in target_dir
|
||||||
extracted_path = path.join(target_dir, extracted_data)
|
extracted_path = os.path.join(target_dir, extracted_data)
|
||||||
if not path.exists(extracted_path):
|
if not os.path.exists(extracted_path):
|
||||||
print('No directory "%s" - extracting archive...' % extracted_path)
|
print('No directory "%s" - extracting archive...' % extracted_path)
|
||||||
if not os.path.isdir(extracted_path):
|
if not os.path.isdir(extracted_path):
|
||||||
os.mkdir(extracted_path)
|
os.mkdir(extracted_path)
|
||||||
|
@ -62,12 +54,12 @@ def one_sample(sample):
|
||||||
""" Take a audio file, and optionally convert it to 16kHz WAV """
|
""" Take a audio file, and optionally convert it to 16kHz WAV """
|
||||||
ogg_filename = sample[0]
|
ogg_filename = sample[0]
|
||||||
# Storing wav files next to the ogg ones - just with a different suffix
|
# Storing wav files next to the ogg ones - just with a different suffix
|
||||||
wav_filename = path.splitext(ogg_filename)[0] + ".wav"
|
wav_filename = os.path.splitext(ogg_filename)[0] + ".wav"
|
||||||
_maybe_convert_wav(ogg_filename, wav_filename)
|
_maybe_convert_wav(ogg_filename, wav_filename)
|
||||||
file_size = -1
|
file_size = -1
|
||||||
frames = 0
|
frames = 0
|
||||||
if path.exists(wav_filename):
|
if os.path.exists(wav_filename):
|
||||||
file_size = path.getsize(wav_filename)
|
file_size = os.path.getsize(wav_filename)
|
||||||
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
|
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
|
||||||
label = label_filter(sample[1])
|
label = label_filter(sample[1])
|
||||||
rows = []
|
rows = []
|
||||||
|
@ -94,7 +86,7 @@ def one_sample(sample):
|
||||||
return (counter, rows)
|
return (counter, rows)
|
||||||
|
|
||||||
def _maybe_convert_sets(target_dir, extracted_data):
|
def _maybe_convert_sets(target_dir, extracted_data):
|
||||||
extracted_dir = path.join(target_dir, extracted_data)
|
extracted_dir = os.path.join(target_dir, extracted_data)
|
||||||
# override existing CSV with normalized one
|
# override existing CSV with normalized one
|
||||||
target_csv_template = os.path.join(target_dir, ARCHIVE_DIR_NAME + '_' + ARCHIVE_NAME.replace('.zip', '_{}.csv'))
|
target_csv_template = os.path.join(target_dir, ARCHIVE_DIR_NAME + '_' + ARCHIVE_NAME.replace('.zip', '_{}.csv'))
|
||||||
if os.path.isfile(target_csv_template):
|
if os.path.isfile(target_csv_template):
|
||||||
|
@ -160,7 +152,7 @@ def _maybe_convert_sets(target_dir, extracted_data):
|
||||||
print_import_report(counter, SAMPLE_RATE, MAX_SECS)
|
print_import_report(counter, SAMPLE_RATE, MAX_SECS)
|
||||||
|
|
||||||
def _maybe_convert_wav(ogg_filename, wav_filename):
|
def _maybe_convert_wav(ogg_filename, wav_filename):
|
||||||
if not path.exists(wav_filename):
|
if not os.path.exists(wav_filename):
|
||||||
transformer = sox.Transformer()
|
transformer = sox.Transformer()
|
||||||
transformer.convert(samplerate=SAMPLE_RATE)
|
transformer.convert(samplerate=SAMPLE_RATE)
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -2,29 +2,19 @@
|
||||||
# pylint: disable=invalid-name
|
# pylint: disable=invalid-name
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
# Make sure we can import stuff from util/
|
|
||||||
# This script needs to be run from the root of the DeepSpeech repository
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
|
|
||||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
|
||||||
|
|
||||||
from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
|
|
||||||
|
|
||||||
import csv
|
import csv
|
||||||
import subprocess
|
import os
|
||||||
import progressbar
|
import progressbar
|
||||||
import unicodedata
|
import subprocess
|
||||||
import tarfile
|
import tarfile
|
||||||
|
import unicodedata
|
||||||
|
|
||||||
from multiprocessing import Pool
|
from deepspeech_training.util.downloader import maybe_download
|
||||||
from util.downloader import SIMPLE_BAR
|
from deepspeech_training.util.downloader import SIMPLE_BAR
|
||||||
|
from deepspeech_training.util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
|
||||||
from os import path
|
from deepspeech_training.util.text import Alphabet
|
||||||
from glob import glob
|
from glob import glob
|
||||||
|
from multiprocessing import Pool
|
||||||
from util.downloader import maybe_download
|
|
||||||
from util.text import Alphabet
|
|
||||||
|
|
||||||
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
||||||
SAMPLE_RATE = 16000
|
SAMPLE_RATE = 16000
|
||||||
|
@ -37,7 +27,7 @@ ARCHIVE_URL = 'http://www.caito.de/data/Training/stt_tts/' + ARCHIVE_NAME
|
||||||
|
|
||||||
def _download_and_preprocess_data(target_dir):
|
def _download_and_preprocess_data(target_dir):
|
||||||
# Making path absolute
|
# Making path absolute
|
||||||
target_dir = path.abspath(target_dir)
|
target_dir = os.path.abspath(target_dir)
|
||||||
# Conditionally download data
|
# Conditionally download data
|
||||||
archive_path = maybe_download(ARCHIVE_NAME, target_dir, ARCHIVE_URL)
|
archive_path = maybe_download(ARCHIVE_NAME, target_dir, ARCHIVE_URL)
|
||||||
# Conditionally extract data
|
# Conditionally extract data
|
||||||
|
@ -48,8 +38,8 @@ def _download_and_preprocess_data(target_dir):
|
||||||
|
|
||||||
def _maybe_extract(target_dir, extracted_data, archive_path):
|
def _maybe_extract(target_dir, extracted_data, archive_path):
|
||||||
# If target_dir/extracted_data does not exist, extract archive in target_dir
|
# If target_dir/extracted_data does not exist, extract archive in target_dir
|
||||||
extracted_path = path.join(target_dir, extracted_data)
|
extracted_path = os.path.join(target_dir, extracted_data)
|
||||||
if not path.exists(extracted_path):
|
if not os.path.exists(extracted_path):
|
||||||
print('No directory "%s" - extracting archive...' % extracted_path)
|
print('No directory "%s" - extracting archive...' % extracted_path)
|
||||||
if not os.path.isdir(extracted_path):
|
if not os.path.isdir(extracted_path):
|
||||||
os.mkdir(extracted_path)
|
os.mkdir(extracted_path)
|
||||||
|
@ -65,8 +55,8 @@ def one_sample(sample):
|
||||||
wav_filename = sample[0]
|
wav_filename = sample[0]
|
||||||
file_size = -1
|
file_size = -1
|
||||||
frames = 0
|
frames = 0
|
||||||
if path.exists(wav_filename):
|
if os.path.exists(wav_filename):
|
||||||
file_size = path.getsize(wav_filename)
|
file_size = os.path.getsize(wav_filename)
|
||||||
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
|
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
|
||||||
label = label_filter(sample[1])
|
label = label_filter(sample[1])
|
||||||
counter = get_counter()
|
counter = get_counter()
|
||||||
|
@ -93,7 +83,7 @@ def one_sample(sample):
|
||||||
return (counter, rows)
|
return (counter, rows)
|
||||||
|
|
||||||
def _maybe_convert_sets(target_dir, extracted_data):
|
def _maybe_convert_sets(target_dir, extracted_data):
|
||||||
extracted_dir = path.join(target_dir, extracted_data)
|
extracted_dir = os.path.join(target_dir, extracted_data)
|
||||||
# override existing CSV with normalized one
|
# override existing CSV with normalized one
|
||||||
target_csv_template = os.path.join(target_dir, ARCHIVE_DIR_NAME, ARCHIVE_NAME.replace('.tgz', '_{}.csv'))
|
target_csv_template = os.path.join(target_dir, ARCHIVE_DIR_NAME, ARCHIVE_NAME.replace('.tgz', '_{}.csv'))
|
||||||
if os.path.isfile(target_csv_template):
|
if os.path.isfile(target_csv_template):
|
||||||
|
|
|
@ -1,18 +1,13 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
# Make sure we can import stuff from util/
|
|
||||||
# This script needs to be run from the root of the DeepSpeech repository
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
|
||||||
|
|
||||||
from util.importers import get_importers_parser
|
|
||||||
import glob
|
import glob
|
||||||
|
import os
|
||||||
import pandas
|
import pandas
|
||||||
import tarfile
|
import tarfile
|
||||||
import wave
|
import wave
|
||||||
|
|
||||||
|
from deepspeech_training.util.importers import get_importers_parser
|
||||||
|
|
||||||
COLUMN_NAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
COLUMN_NAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
||||||
|
|
||||||
|
|
|
@ -1,19 +1,14 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
# Make sure we can import stuff from util/
|
|
||||||
# This script needs to be run from the root of the DeepSpeech repository
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
|
||||||
|
|
||||||
from util.importers import get_importers_parser
|
|
||||||
import glob
|
import glob
|
||||||
import json
|
import json
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import os
|
||||||
import pandas
|
import pandas
|
||||||
import tarfile
|
import tarfile
|
||||||
|
|
||||||
|
from deepspeech_training.util.importers import get_importers_parser
|
||||||
|
|
||||||
COLUMN_NAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
COLUMN_NAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
||||||
|
|
||||||
|
|
|
@ -1,32 +1,22 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
# Make sure we can import stuff from util/
|
|
||||||
# This script needs to be run from the root of the DeepSpeech repository
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
|
||||||
|
|
||||||
from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
|
|
||||||
|
|
||||||
import csv
|
import csv
|
||||||
|
import os
|
||||||
|
import progressbar
|
||||||
import re
|
import re
|
||||||
import sox
|
import sox
|
||||||
import zipfile
|
|
||||||
import subprocess
|
import subprocess
|
||||||
import progressbar
|
|
||||||
import unicodedata
|
|
||||||
import tarfile
|
import tarfile
|
||||||
|
import unicodedata
|
||||||
|
import zipfile
|
||||||
|
|
||||||
from multiprocessing import Pool
|
from deepspeech_training.util.downloader import maybe_download, SIMPLE_BAR
|
||||||
from util.downloader import SIMPLE_BAR
|
from deepspeech_training.util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
|
||||||
|
from deepspeech_training.util.text import Alphabet
|
||||||
from os import path
|
|
||||||
from glob import glob
|
from glob import glob
|
||||||
|
from multiprocessing import Pool
|
||||||
|
|
||||||
from util.downloader import maybe_download
|
|
||||||
from util.text import Alphabet
|
|
||||||
from util.helpers import secs_to_hours
|
|
||||||
|
|
||||||
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
||||||
SAMPLE_RATE = 16000
|
SAMPLE_RATE = 16000
|
||||||
|
@ -39,7 +29,7 @@ ARCHIVE_URL = 'http://www.openslr.org/resources/57/' + ARCHIVE_NAME
|
||||||
|
|
||||||
def _download_and_preprocess_data(target_dir):
|
def _download_and_preprocess_data(target_dir):
|
||||||
# Making path absolute
|
# Making path absolute
|
||||||
target_dir = path.abspath(target_dir)
|
target_dir = os.path.abspath(target_dir)
|
||||||
# Conditionally download data
|
# Conditionally download data
|
||||||
archive_path = maybe_download(ARCHIVE_NAME, target_dir, ARCHIVE_URL)
|
archive_path = maybe_download(ARCHIVE_NAME, target_dir, ARCHIVE_URL)
|
||||||
# Conditionally extract data
|
# Conditionally extract data
|
||||||
|
@ -49,8 +39,8 @@ def _download_and_preprocess_data(target_dir):
|
||||||
|
|
||||||
def _maybe_extract(target_dir, extracted_data, archive_path):
|
def _maybe_extract(target_dir, extracted_data, archive_path):
|
||||||
# If target_dir/extracted_data does not exist, extract archive in target_dir
|
# If target_dir/extracted_data does not exist, extract archive in target_dir
|
||||||
extracted_path = path.join(target_dir, extracted_data)
|
extracted_path = os.path.join(target_dir, extracted_data)
|
||||||
if not path.exists(extracted_path):
|
if not os.path.exists(extracted_path):
|
||||||
print('No directory "%s" - extracting archive...' % extracted_path)
|
print('No directory "%s" - extracting archive...' % extracted_path)
|
||||||
if not os.path.isdir(extracted_path):
|
if not os.path.isdir(extracted_path):
|
||||||
os.mkdir(extracted_path)
|
os.mkdir(extracted_path)
|
||||||
|
@ -65,8 +55,8 @@ def one_sample(sample):
|
||||||
wav_filename = sample[0]
|
wav_filename = sample[0]
|
||||||
file_size = -1
|
file_size = -1
|
||||||
frames = 0
|
frames = 0
|
||||||
if path.exists(wav_filename):
|
if os.path.exists(wav_filename):
|
||||||
file_size = path.getsize(wav_filename)
|
file_size = os.path.getsize(wav_filename)
|
||||||
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
|
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
|
||||||
label = label_filter(sample[1])
|
label = label_filter(sample[1])
|
||||||
counter = get_counter()
|
counter = get_counter()
|
||||||
|
@ -92,7 +82,7 @@ def one_sample(sample):
|
||||||
return (counter, rows)
|
return (counter, rows)
|
||||||
|
|
||||||
def _maybe_convert_sets(target_dir, extracted_data):
|
def _maybe_convert_sets(target_dir, extracted_data):
|
||||||
extracted_dir = path.join(target_dir, extracted_data)
|
extracted_dir = os.path.join(target_dir, extracted_data)
|
||||||
# override existing CSV with normalized one
|
# override existing CSV with normalized one
|
||||||
target_csv_template = os.path.join(target_dir, ARCHIVE_DIR_NAME, ARCHIVE_NAME.replace('.tar.gz', '_{}.csv'))
|
target_csv_template = os.path.join(target_dir, ARCHIVE_DIR_NAME, ARCHIVE_NAME.replace('.tar.gz', '_{}.csv'))
|
||||||
if os.path.isfile(target_csv_template):
|
if os.path.isfile(target_csv_template):
|
||||||
|
|
|
@ -1,28 +1,24 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
# Make sure we can import stuff from util/
|
|
||||||
# This script needs to be run from the root of the DeepSpeech repository
|
|
||||||
|
|
||||||
# ensure that you have downloaded the LDC dataset LDC97S62 and tar exists in a folder e.g.
|
# ensure that you have downloaded the LDC dataset LDC97S62 and tar exists in a folder e.g.
|
||||||
# ./data/swb/swb1_LDC97S62.tgz
|
# ./data/swb/swb1_LDC97S62.tgz
|
||||||
# from the deepspeech directory run with: ./bin/import_swb.py ./data/swb/
|
# from the deepspeech directory run with: ./bin/import_swb.py ./data/swb/
|
||||||
|
import codecs
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
|
||||||
|
|
||||||
import fnmatch
|
import fnmatch
|
||||||
|
import librosa
|
||||||
|
import os
|
||||||
import pandas
|
import pandas
|
||||||
|
import requests
|
||||||
|
import soundfile # <= Has an external dependency on libsndfile
|
||||||
import subprocess
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import tarfile
|
||||||
import unicodedata
|
import unicodedata
|
||||||
import wave
|
import wave
|
||||||
import codecs
|
|
||||||
import tarfile
|
from deepspeech_training.util.importers import validate_label_eng as validate_label
|
||||||
import requests
|
|
||||||
from util.importers import validate_label_eng as validate_label
|
|
||||||
import librosa
|
|
||||||
import soundfile # <= Has an external dependency on libsndfile
|
|
||||||
|
|
||||||
# ARCHIVE_NAME refers to ISIP alignments from 01/29/03
|
# ARCHIVE_NAME refers to ISIP alignments from 01/29/03
|
||||||
ARCHIVE_NAME = 'switchboard_word_alignments.tar.gz'
|
ARCHIVE_NAME = 'switchboard_word_alignments.tar.gz'
|
||||||
|
|
|
@ -5,31 +5,26 @@ Use "python3 import_swc.py -h" for help
|
||||||
'''
|
'''
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
# Make sure we can import stuff from util/
|
|
||||||
# This script needs to be run from the root of the DeepSpeech repository
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
|
||||||
|
|
||||||
import re
|
|
||||||
import csv
|
|
||||||
import sox
|
|
||||||
import wave
|
|
||||||
import shutil
|
|
||||||
import random
|
|
||||||
import tarfile
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import csv
|
||||||
|
import os
|
||||||
import progressbar
|
import progressbar
|
||||||
|
import random
|
||||||
|
import re
|
||||||
|
import shutil
|
||||||
|
import sox
|
||||||
|
import sys
|
||||||
|
import tarfile
|
||||||
import unicodedata
|
import unicodedata
|
||||||
|
import wave
|
||||||
import xml.etree.cElementTree as ET
|
import xml.etree.cElementTree as ET
|
||||||
|
|
||||||
from os import path
|
|
||||||
from glob import glob
|
from glob import glob
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from multiprocessing.pool import ThreadPool
|
from multiprocessing.pool import ThreadPool
|
||||||
from util.text import Alphabet
|
from deepspeech_training.util.text import Alphabet
|
||||||
from util.importers import validate_label_eng as validate_label
|
from deepspeech_training.util.importers import validate_label_eng as validate_label
|
||||||
from util.downloader import maybe_download, SIMPLE_BAR
|
from deepspeech_training.util.downloader import maybe_download, SIMPLE_BAR
|
||||||
|
|
||||||
SWC_URL = "https://www2.informatik.uni-hamburg.de/nats/pub/SWC/SWC_{language}.tar"
|
SWC_URL = "https://www2.informatik.uni-hamburg.de/nats/pub/SWC/SWC_{language}.tar"
|
||||||
SWC_ARCHIVE = "SWC_{language}.tar"
|
SWC_ARCHIVE = "SWC_{language}.tar"
|
||||||
|
@ -117,8 +112,8 @@ def maybe_download_language(language):
|
||||||
|
|
||||||
|
|
||||||
def maybe_extract(data_dir, extracted_data, archive):
|
def maybe_extract(data_dir, extracted_data, archive):
|
||||||
extracted = path.join(data_dir, extracted_data)
|
extracted = os.path.join(data_dir, extracted_data)
|
||||||
if path.isdir(extracted):
|
if os.path.isdir(extracted):
|
||||||
print('Found directory "{}" - not extracting.'.format(extracted))
|
print('Found directory "{}" - not extracting.'.format(extracted))
|
||||||
else:
|
else:
|
||||||
print('Extracting "{}"...'.format(archive))
|
print('Extracting "{}"...'.format(archive))
|
||||||
|
@ -242,7 +237,7 @@ def collect_samples(base_dir, language):
|
||||||
print('Collecting samples...')
|
print('Collecting samples...')
|
||||||
bar = progressbar.ProgressBar(max_value=len(roots), widgets=SIMPLE_BAR)
|
bar = progressbar.ProgressBar(max_value=len(roots), widgets=SIMPLE_BAR)
|
||||||
for root in bar(roots):
|
for root in bar(roots):
|
||||||
wav_path = path.join(root, WAV_NAME)
|
wav_path = os.path.join(root, WAV_NAME)
|
||||||
aligned = ET.parse(path.join(root, ALIGNED_NAME))
|
aligned = ET.parse(os.path.join(root, ALIGNED_NAME))
|
||||||
article = UNKNOWN
|
article = UNKNOWN
|
||||||
speaker = UNKNOWN
|
speaker = UNKNOWN
|
||||||
|
@ -294,8 +289,8 @@ def maybe_convert_one_to_wav(entry):
|
||||||
transformer.convert(samplerate=SAMPLE_RATE, n_channels=CHANNELS)
|
transformer.convert(samplerate=SAMPLE_RATE, n_channels=CHANNELS)
|
||||||
combiner = sox.Combiner()
|
combiner = sox.Combiner()
|
||||||
combiner.convert(samplerate=SAMPLE_RATE, n_channels=CHANNELS)
|
combiner.convert(samplerate=SAMPLE_RATE, n_channels=CHANNELS)
|
||||||
output_wav = path.join(root, WAV_NAME)
|
output_wav = os.path.join(root, WAV_NAME)
|
||||||
if path.isfile(output_wav):
|
if os.path.isfile(output_wav):
|
||||||
return
|
return
|
||||||
files = sorted(glob(path.join(root, AUDIO_PATTERN)))
|
files = sorted(glob(path.join(root, AUDIO_PATTERN)))
|
||||||
try:
|
try:
|
||||||
|
@ -304,7 +299,7 @@ def maybe_convert_one_to_wav(entry):
|
||||||
elif len(files) > 1:
|
elif len(files) > 1:
|
||||||
wav_files = []
|
wav_files = []
|
||||||
for i, file in enumerate(files):
|
for i, file in enumerate(files):
|
||||||
wav_path = path.join(root, 'audio{}.wav'.format(i))
|
wav_path = os.path.join(root, 'audio{}.wav'.format(i))
|
||||||
transformer.build(file, wav_path)
|
transformer.build(file, wav_path)
|
||||||
wav_files.append(wav_path)
|
wav_files.append(wav_path)
|
||||||
combiner.set_input_format(file_type=['wav'] * len(wav_files))
|
combiner.set_input_format(file_type=['wav'] * len(wav_files))
|
||||||
|
@ -358,8 +353,8 @@ def assign_sub_sets(samples):
|
||||||
def create_sample_dirs(language):
|
def create_sample_dirs(language):
|
||||||
print('Creating sample directories...')
|
print('Creating sample directories...')
|
||||||
for set_name in ['train', 'dev', 'test']:
|
for set_name in ['train', 'dev', 'test']:
|
||||||
dir_path = path.join(CLI_ARGS.base_dir, language + '-' + set_name)
|
dir_path = os.path.join(CLI_ARGS.base_dir, language + '-' + set_name)
|
||||||
if not path.isdir(dir_path):
|
if not os.path.isdir(dir_path):
|
||||||
os.mkdir(dir_path)
|
os.mkdir(dir_path)
|
||||||
|
|
||||||
|
|
||||||
|
@ -374,7 +369,7 @@ def split_audio_files(samples, language):
|
||||||
rate = src_wav_file.getframerate()
|
rate = src_wav_file.getframerate()
|
||||||
for sample in file_samples:
|
for sample in file_samples:
|
||||||
index = sub_sets[sample.sub_set]
|
index = sub_sets[sample.sub_set]
|
||||||
sample_wav_path = path.join(CLI_ARGS.base_dir,
|
sample_wav_path = os.path.join(CLI_ARGS.base_dir,
|
||||||
language + '-' + sample.sub_set,
|
language + '-' + sample.sub_set,
|
||||||
'sample-{0:06d}.wav'.format(index))
|
'sample-{0:06d}.wav'.format(index))
|
||||||
sample.wav_path = sample_wav_path
|
sample.wav_path = sample_wav_path
|
||||||
|
@ -391,8 +386,8 @@ def split_audio_files(samples, language):
|
||||||
def write_csvs(samples, language):
|
def write_csvs(samples, language):
|
||||||
for sub_set, set_samples in group(samples, lambda s: s.sub_set).items():
|
for sub_set, set_samples in group(samples, lambda s: s.sub_set).items():
|
||||||
set_samples = sorted(set_samples, key=lambda s: s.wav_path)
|
set_samples = sorted(set_samples, key=lambda s: s.wav_path)
|
||||||
base_dir = path.abspath(CLI_ARGS.base_dir)
|
base_dir = os.path.abspath(CLI_ARGS.base_dir)
|
||||||
csv_path = path.join(base_dir, language + '-' + sub_set + '.csv')
|
csv_path = os.path.join(base_dir, language + '-' + sub_set + '.csv')
|
||||||
print('Writing "{}"...'.format(csv_path))
|
print('Writing "{}"...'.format(csv_path))
|
||||||
with open(csv_path, 'w') as csv_file:
|
with open(csv_path, 'w') as csv_file:
|
||||||
writer = csv.DictWriter(csv_file, fieldnames=FIELDNAMES_EXT if CLI_ARGS.add_meta else FIELDNAMES)
|
writer = csv.DictWriter(csv_file, fieldnames=FIELDNAMES_EXT if CLI_ARGS.add_meta else FIELDNAMES)
|
||||||
|
@ -400,8 +395,8 @@ def write_csvs(samples, language):
|
||||||
bar = progressbar.ProgressBar(max_value=len(set_samples), widgets=SIMPLE_BAR)
|
bar = progressbar.ProgressBar(max_value=len(set_samples), widgets=SIMPLE_BAR)
|
||||||
for sample in bar(set_samples):
|
for sample in bar(set_samples):
|
||||||
row = {
|
row = {
|
||||||
'wav_filename': path.relpath(sample.wav_path, base_dir),
|
'wav_filename': os.path.relpath(sample.wav_path, base_dir),
|
||||||
'wav_filesize': path.getsize(sample.wav_path),
|
'wav_filesize': os.path.getsize(sample.wav_path),
|
||||||
'transcript': sample.text
|
'transcript': sample.text
|
||||||
}
|
}
|
||||||
if CLI_ARGS.add_meta:
|
if CLI_ARGS.add_meta:
|
||||||
|
@ -414,8 +409,8 @@ def cleanup(archive, language):
|
||||||
if not CLI_ARGS.keep_archive:
|
if not CLI_ARGS.keep_archive:
|
||||||
print('Removing archive "{}"...'.format(archive))
|
print('Removing archive "{}"...'.format(archive))
|
||||||
os.remove(archive)
|
os.remove(archive)
|
||||||
language_dir = path.join(CLI_ARGS.base_dir, language)
|
language_dir = os.path.join(CLI_ARGS.base_dir, language)
|
||||||
if not CLI_ARGS.keep_intermediate and path.isdir(language_dir):
|
if not CLI_ARGS.keep_intermediate and os.path.isdir(language_dir):
|
||||||
print('Removing intermediate files in "{}"...'.format(language_dir))
|
print('Removing intermediate files in "{}"...'.format(language_dir))
|
||||||
shutil.rmtree(language_dir)
|
shutil.rmtree(language_dir)
|
||||||
|
|
||||||
|
|
|
@ -1,14 +1,8 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
# Make sure we can import stuff from util/
|
|
||||||
# This script needs to be run from the root of the DeepSpeech repository
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
|
||||||
|
|
||||||
import codecs
|
|
||||||
import pandas
|
import pandas
|
||||||
|
import sys
|
||||||
import tarfile
|
import tarfile
|
||||||
import unicodedata
|
import unicodedata
|
||||||
import wave
|
import wave
|
||||||
|
@ -16,9 +10,10 @@ import wave
|
||||||
from glob import glob
|
from glob import glob
|
||||||
from os import makedirs, path, remove, rmdir
|
from os import makedirs, path, remove, rmdir
|
||||||
from sox import Transformer
|
from sox import Transformer
|
||||||
from util.downloader import maybe_download
|
from deepspeech_training.util.downloader import maybe_download
|
||||||
from tensorflow.python.platform import gfile
|
from tensorflow.python.platform import gfile
|
||||||
from util.stm import parse_stm_file
|
from deepspeech_training.util.stm import parse_stm_file
|
||||||
|
|
||||||
|
|
||||||
def _download_and_preprocess_data(data_dir):
|
def _download_and_preprocess_data(data_dir):
|
||||||
# Conditionally download data
|
# Conditionally download data
|
||||||
|
|
|
@ -1,28 +1,20 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
# Make sure we can import stuff from util/
|
|
||||||
# This script needs to be run from the root of the DeepSpeech repository
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
import sys
|
|
||||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
|
||||||
|
|
||||||
from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
|
|
||||||
|
|
||||||
import csv
|
import csv
|
||||||
import unidecode
|
import os
|
||||||
import zipfile
|
import progressbar
|
||||||
|
import re
|
||||||
import sox
|
import sox
|
||||||
import subprocess
|
import subprocess
|
||||||
import progressbar
|
import unidecode
|
||||||
|
import zipfile
|
||||||
|
|
||||||
|
from deepspeech_training.util.downloader import maybe_download
|
||||||
|
from deepspeech_training.util.downloader import SIMPLE_BAR
|
||||||
|
from deepspeech_training.util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
|
||||||
from multiprocessing import Pool
|
from multiprocessing import Pool
|
||||||
from util.downloader import SIMPLE_BAR
|
|
||||||
|
|
||||||
from os import path
|
|
||||||
|
|
||||||
from util.downloader import maybe_download
|
|
||||||
|
|
||||||
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
||||||
SAMPLE_RATE = 16000
|
SAMPLE_RATE = 16000
|
||||||
|
@ -34,7 +26,7 @@ ARCHIVE_URL = 'https://deepspeech-storage-mirror.s3.fr-par.scw.cloud/' + ARCHIVE
|
||||||
|
|
||||||
def _download_and_preprocess_data(target_dir, english_compatible=False):
|
def _download_and_preprocess_data(target_dir, english_compatible=False):
|
||||||
# Making path absolute
|
# Making path absolute
|
||||||
target_dir = path.abspath(target_dir)
|
target_dir = os.path.abspath(target_dir)
|
||||||
# Conditionally download data
|
# Conditionally download data
|
||||||
archive_path = maybe_download('ts_' + ARCHIVE_NAME + '.zip', target_dir, ARCHIVE_URL)
|
archive_path = maybe_download('ts_' + ARCHIVE_NAME + '.zip', target_dir, ARCHIVE_URL)
|
||||||
# Conditionally extract archive data
|
# Conditionally extract archive data
|
||||||
|
@ -45,8 +37,8 @@ def _download_and_preprocess_data(target_dir, english_compatible=False):
|
||||||
|
|
||||||
def _maybe_extract(target_dir, extracted_data, archive_path):
|
def _maybe_extract(target_dir, extracted_data, archive_path):
|
||||||
# If target_dir/extracted_data does not exist, extract archive in target_dir
|
# If target_dir/extracted_data does not exist, extract archive in target_dir
|
||||||
extracted_path = path.join(target_dir, extracted_data)
|
extracted_path = os.path.join(target_dir, extracted_data)
|
||||||
if not path.exists(extracted_path):
|
if not os.path.exists(extracted_path):
|
||||||
print('No directory "%s" - extracting archive...' % extracted_path)
|
print('No directory "%s" - extracting archive...' % extracted_path)
|
||||||
if not os.path.isdir(extracted_path):
|
if not os.path.isdir(extracted_path):
|
||||||
os.mkdir(extracted_path)
|
os.mkdir(extracted_path)
|
||||||
|
@ -60,12 +52,12 @@ def one_sample(sample):
|
||||||
""" Take a audio file, and optionally convert it to 16kHz WAV """
|
""" Take a audio file, and optionally convert it to 16kHz WAV """
|
||||||
orig_filename = sample['path']
|
orig_filename = sample['path']
|
||||||
# Storing wav files next to the wav ones - just with a different suffix
|
# Storing wav files next to the wav ones - just with a different suffix
|
||||||
wav_filename = path.splitext(orig_filename)[0] + ".converted.wav"
|
wav_filename = os.path.splitext(orig_filename)[0] + ".converted.wav"
|
||||||
_maybe_convert_wav(orig_filename, wav_filename)
|
_maybe_convert_wav(orig_filename, wav_filename)
|
||||||
file_size = -1
|
file_size = -1
|
||||||
frames = 0
|
frames = 0
|
||||||
if path.exists(wav_filename):
|
if os.path.exists(wav_filename):
|
||||||
file_size = path.getsize(wav_filename)
|
file_size = os.path.getsize(wav_filename)
|
||||||
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
|
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
|
||||||
label = sample['text']
|
label = sample['text']
|
||||||
|
|
||||||
|
@ -95,7 +87,7 @@ def one_sample(sample):
|
||||||
|
|
||||||
|
|
||||||
def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False):
|
def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False):
|
||||||
extracted_dir = path.join(target_dir, extracted_data)
|
extracted_dir = os.path.join(target_dir, extracted_data)
|
||||||
# override existing CSV with normalized one
|
# override existing CSV with normalized one
|
||||||
target_csv_template = os.path.join(target_dir, 'ts_' + ARCHIVE_NAME + '_{}.csv')
|
target_csv_template = os.path.join(target_dir, 'ts_' + ARCHIVE_NAME + '_{}.csv')
|
||||||
if os.path.isfile(target_csv_template):
|
if os.path.isfile(target_csv_template):
|
||||||
|
@ -160,7 +152,7 @@ def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False):
|
||||||
print_import_report(counter, SAMPLE_RATE, MAX_SECS)
|
print_import_report(counter, SAMPLE_RATE, MAX_SECS)
|
||||||
|
|
||||||
def _maybe_convert_wav(orig_filename, wav_filename):
|
def _maybe_convert_wav(orig_filename, wav_filename):
|
||||||
if not path.exists(wav_filename):
|
if not os.path.exists(wav_filename):
|
||||||
transformer = sox.Transformer()
|
transformer = sox.Transformer()
|
||||||
transformer.convert(samplerate=SAMPLE_RATE)
|
transformer.convert(samplerate=SAMPLE_RATE)
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -5,25 +5,19 @@ Use "python3 import_tuda.py -h" for help
|
||||||
'''
|
'''
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
# Make sure we can import stuff from util/
|
|
||||||
# This script needs to be run from the root of the DeepSpeech repository
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
|
||||||
|
|
||||||
import csv
|
|
||||||
import wave
|
|
||||||
import tarfile
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import csv
|
||||||
|
import os
|
||||||
import progressbar
|
import progressbar
|
||||||
|
import tarfile
|
||||||
import unicodedata
|
import unicodedata
|
||||||
|
import wave
|
||||||
import xml.etree.cElementTree as ET
|
import xml.etree.cElementTree as ET
|
||||||
|
|
||||||
from os import path
|
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from util.text import Alphabet
|
from deepspeech_training.util.downloader import maybe_download, SIMPLE_BAR
|
||||||
from util.importers import validate_label_eng as validate_label
|
from deepspeech_training.util.importers import validate_label_eng as validate_label
|
||||||
from util.downloader import maybe_download, SIMPLE_BAR
|
from deepspeech_training.util.text import Alphabet
|
||||||
|
|
||||||
TUDA_VERSION = 'v2'
|
TUDA_VERSION = 'v2'
|
||||||
TUDA_PACKAGE = 'german-speechdata-package-{}'.format(TUDA_VERSION)
|
TUDA_PACKAGE = 'german-speechdata-package-{}'.format(TUDA_VERSION)
|
||||||
|
@ -38,8 +32,8 @@ FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
||||||
|
|
||||||
|
|
||||||
def maybe_extract(archive):
|
def maybe_extract(archive):
|
||||||
extracted = path.join(CLI_ARGS.base_dir, TUDA_PACKAGE)
|
extracted = os.path.join(CLI_ARGS.base_dir, TUDA_PACKAGE)
|
||||||
if path.isdir(extracted):
|
if os.path.isdir(extracted):
|
||||||
print('Found directory "{}" - not extracting.'.format(extracted))
|
print('Found directory "{}" - not extracting.'.format(extracted))
|
||||||
else:
|
else:
|
||||||
print('Extracting "{}"...'.format(archive))
|
print('Extracting "{}"...'.format(archive))
|
||||||
|
@ -92,7 +86,7 @@ def write_csvs(extracted):
|
||||||
sample_counter = 0
|
sample_counter = 0
|
||||||
reasons = Counter()
|
reasons = Counter()
|
||||||
for sub_set in ['train', 'dev', 'test']:
|
for sub_set in ['train', 'dev', 'test']:
|
||||||
set_path = path.join(extracted, sub_set)
|
set_path = os.path.join(extracted, sub_set)
|
||||||
set_files = os.listdir(set_path)
|
set_files = os.listdir(set_path)
|
||||||
recordings = {}
|
recordings = {}
|
||||||
for file in set_files:
|
for file in set_files:
|
||||||
|
@ -104,15 +98,15 @@ def write_csvs(extracted):
|
||||||
if prefix in recordings:
|
if prefix in recordings:
|
||||||
recordings[prefix].append(file)
|
recordings[prefix].append(file)
|
||||||
recordings = recordings.items()
|
recordings = recordings.items()
|
||||||
csv_path = path.join(CLI_ARGS.base_dir, 'tuda-{}-{}.csv'.format(TUDA_VERSION, sub_set))
|
csv_path = os.path.join(CLI_ARGS.base_dir, 'tuda-{}-{}.csv'.format(TUDA_VERSION, sub_set))
|
||||||
print('Writing "{}"...'.format(csv_path))
|
print('Writing "{}"...'.format(csv_path))
|
||||||
with open(csv_path, 'w') as csv_file:
|
with open(csv_path, 'w') as csv_file:
|
||||||
writer = csv.DictWriter(csv_file, fieldnames=FIELDNAMES)
|
writer = csv.DictWriter(csv_file, fieldnames=FIELDNAMES)
|
||||||
writer.writeheader()
|
writer.writeheader()
|
||||||
set_dir = path.join(extracted, sub_set)
|
set_dir = os.path.join(extracted, sub_set)
|
||||||
bar = progressbar.ProgressBar(max_value=len(recordings), widgets=SIMPLE_BAR)
|
bar = progressbar.ProgressBar(max_value=len(recordings), widgets=SIMPLE_BAR)
|
||||||
for prefix, wav_names in bar(recordings):
|
for prefix, wav_names in bar(recordings):
|
||||||
xml_path = path.join(set_dir, prefix + '.xml')
|
xml_path = os.path.join(set_dir, prefix + '.xml')
|
||||||
meta = ET.parse(xml_path).getroot()
|
meta = ET.parse(xml_path).getroot()
|
||||||
sentence = list(meta.iter('cleaned_sentence'))[0].text
|
sentence = list(meta.iter('cleaned_sentence'))[0].text
|
||||||
sentence = check_and_prepare_sentence(sentence)
|
sentence = check_and_prepare_sentence(sentence)
|
||||||
|
@ -120,12 +114,12 @@ def write_csvs(extracted):
|
||||||
continue
|
continue
|
||||||
for wav_name in wav_names:
|
for wav_name in wav_names:
|
||||||
sample_counter += 1
|
sample_counter += 1
|
||||||
wav_path = path.join(set_path, wav_name)
|
wav_path = os.path.join(set_path, wav_name)
|
||||||
keep, reason = check_wav_file(wav_path, sentence)
|
keep, reason = check_wav_file(wav_path, sentence)
|
||||||
if keep:
|
if keep:
|
||||||
writer.writerow({
|
writer.writerow({
|
||||||
'wav_filename': path.relpath(wav_path, CLI_ARGS.base_dir),
|
'wav_filename': os.path.relpath(wav_path, CLI_ARGS.base_dir),
|
||||||
'wav_filesize': path.getsize(wav_path),
|
'wav_filesize': os.path.getsize(wav_path),
|
||||||
'transcript': sentence.lower()
|
'transcript': sentence.lower()
|
||||||
})
|
})
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -1,30 +1,21 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
|
|
||||||
# VCTK used in wavenet paper https://arxiv.org/pdf/1609.03499.pdf
|
# VCTK used in wavenet paper https://arxiv.org/pdf/1609.03499.pdf
|
||||||
# Licenced under Open Data Commons Attribution License (ODC-By) v1.0.
|
# Licenced under Open Data Commons Attribution License (ODC-By) v1.0.
|
||||||
# as per https://homepages.inf.ed.ac.uk/jyamagis/page3/page58/page58.html
|
# as per https://homepages.inf.ed.ac.uk/jyamagis/page3/page58/page58.html
|
||||||
|
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
# Make sure we can import stuff from util/
|
|
||||||
# This script needs to be run from the root of the DeepSpeech repository
|
|
||||||
import os
|
|
||||||
import random
|
|
||||||
import sys
|
|
||||||
|
|
||||||
sys.path.insert(1, os.path.join(sys.path[0], ".."))
|
|
||||||
|
|
||||||
from util.importers import get_counter, get_imported_samples, print_import_report
|
|
||||||
|
|
||||||
import re
|
|
||||||
import librosa
|
import librosa
|
||||||
|
import os
|
||||||
import progressbar
|
import progressbar
|
||||||
|
import random
|
||||||
|
import re
|
||||||
|
|
||||||
from os import path
|
from deepspeech_training.util.downloader import maybe_download, SIMPLE_BAR
|
||||||
|
from deepspeech_training.util.importers import get_counter, get_imported_samples, print_import_report
|
||||||
from multiprocessing import Pool
|
from multiprocessing import Pool
|
||||||
from util.downloader import maybe_download, SIMPLE_BAR
|
|
||||||
from zipfile import ZipFile
|
from zipfile import ZipFile
|
||||||
|
|
||||||
|
|
||||||
SAMPLE_RATE = 16000
|
SAMPLE_RATE = 16000
|
||||||
MAX_SECS = 10
|
MAX_SECS = 10
|
||||||
MIN_SECS = 1
|
MIN_SECS = 1
|
||||||
|
@ -37,7 +28,7 @@ ARCHIVE_URL = (
|
||||||
|
|
||||||
def _download_and_preprocess_data(target_dir):
|
def _download_and_preprocess_data(target_dir):
|
||||||
# Making path absolute
|
# Making path absolute
|
||||||
target_dir = path.abspath(target_dir)
|
target_dir = os.path.abspath(target_dir)
|
||||||
# Conditionally download data
|
# Conditionally download data
|
||||||
archive_path = maybe_download(ARCHIVE_NAME, target_dir, ARCHIVE_URL)
|
archive_path = maybe_download(ARCHIVE_NAME, target_dir, ARCHIVE_URL)
|
||||||
# Conditionally extract common voice data
|
# Conditionally extract common voice data
|
||||||
|
@ -48,8 +39,8 @@ def _download_and_preprocess_data(target_dir):
|
||||||
|
|
||||||
def _maybe_extract(target_dir, extracted_data, archive_path):
|
def _maybe_extract(target_dir, extracted_data, archive_path):
|
||||||
# If target_dir/extracted_data does not exist, extract archive in target_dir
|
# If target_dir/extracted_data does not exist, extract archive in target_dir
|
||||||
extracted_path = path.join(target_dir, extracted_data)
|
extracted_path = os.path.join(target_dir, extracted_data)
|
||||||
if not path.exists(extracted_path):
|
if not os.path.exists(extracted_path):
|
||||||
print(f"No directory {extracted_path} - extracting archive...")
|
print(f"No directory {extracted_path} - extracting archive...")
|
||||||
with ZipFile(archive_path, "r") as zipobj:
|
with ZipFile(archive_path, "r") as zipobj:
|
||||||
# Extract all the contents of zip file in current directory
|
# Extract all the contents of zip file in current directory
|
||||||
|
@ -59,8 +50,8 @@ def _maybe_extract(target_dir, extracted_data, archive_path):
|
||||||
|
|
||||||
|
|
||||||
def _maybe_convert_sets(target_dir, extracted_data):
|
def _maybe_convert_sets(target_dir, extracted_data):
|
||||||
extracted_dir = path.join(target_dir, extracted_data, "wav48")
|
extracted_dir = os.path.join(target_dir, extracted_data, "wav48")
|
||||||
txt_dir = path.join(target_dir, extracted_data, "txt")
|
txt_dir = os.path.join(target_dir, extracted_data, "txt")
|
||||||
|
|
||||||
directory = os.path.expanduser(extracted_dir)
|
directory = os.path.expanduser(extracted_dir)
|
||||||
srtd = len(sorted(os.listdir(directory)))
|
srtd = len(sorted(os.listdir(directory)))
|
||||||
|
|
|
@ -2,23 +2,20 @@
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
import codecs
|
import codecs
|
||||||
import sys
|
|
||||||
import os
|
import os
|
||||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
|
||||||
|
|
||||||
import tarfile
|
|
||||||
import pandas
|
import pandas
|
||||||
import re
|
import re
|
||||||
import unicodedata
|
import tarfile
|
||||||
import threading
|
import threading
|
||||||
from multiprocessing.pool import ThreadPool
|
import unicodedata
|
||||||
|
|
||||||
from six.moves import urllib
|
|
||||||
from glob import glob
|
|
||||||
from os import makedirs, path
|
|
||||||
from bs4 import BeautifulSoup
|
from bs4 import BeautifulSoup
|
||||||
|
from deepspeech_training.util.downloader import maybe_download
|
||||||
|
from glob import glob
|
||||||
|
from multiprocessing.pool import ThreadPool
|
||||||
|
from os import makedirs, path
|
||||||
|
from six.moves import urllib
|
||||||
from tensorflow.python.platform import gfile
|
from tensorflow.python.platform import gfile
|
||||||
from util.downloader import maybe_download
|
|
||||||
|
|
||||||
"""The number of jobs to run in parallel"""
|
"""The number of jobs to run in parallel"""
|
||||||
NUM_PARALLEL = 8
|
NUM_PARALLEL = 8
|
||||||
|
@ -99,7 +96,7 @@ def _parallel_extracter(data_dir, number_of_test, number_of_dev, total, counter)
|
||||||
dataset_dir = path.join(data_dir, "dev")
|
dataset_dir = path.join(data_dir, "dev")
|
||||||
else:
|
else:
|
||||||
dataset_dir = path.join(data_dir, "train")
|
dataset_dir = path.join(data_dir, "train")
|
||||||
if not gfile.Exists(path.join(dataset_dir, '.'.join(filename_of(archive).split(".")[:-1]))):
|
if not gfile.Exists(os.path.join(dataset_dir, '.'.join(filename_of(archive).split(".")[:-1]))):
|
||||||
c = counter.increment()
|
c = counter.increment()
|
||||||
print('Extracting file {} ({}/{})...'.format(i+1, c, total))
|
print('Extracting file {} ({}/{})...'.format(i+1, c, total))
|
||||||
tar = tarfile.open(archive)
|
tar = tarfile.open(archive)
|
||||||
|
@ -132,14 +129,14 @@ def _download_and_preprocess_data(data_dir):
|
||||||
p.map(downloader, enumerate(refs))
|
p.map(downloader, enumerate(refs))
|
||||||
|
|
||||||
# Conditionally extract data to dataset_dir
|
# Conditionally extract data to dataset_dir
|
||||||
if not path.isdir(path.join(data_dir,"test")):
|
if not path.isdir(os.path.join(data_dir, "test")):
|
||||||
makedirs(path.join(data_dir,"test"))
|
makedirs(os.path.join(data_dir, "test"))
|
||||||
if not path.isdir(path.join(data_dir,"dev")):
|
if not path.isdir(os.path.join(data_dir, "dev")):
|
||||||
makedirs(path.join(data_dir,"dev"))
|
makedirs(os.path.join(data_dir, "dev"))
|
||||||
if not path.isdir(path.join(data_dir,"train")):
|
if not path.isdir(os.path.join(data_dir, "train")):
|
||||||
makedirs(path.join(data_dir,"train"))
|
makedirs(os.path.join(data_dir, "train"))
|
||||||
|
|
||||||
tarfiles = glob(path.join(archive_dir, "*.tgz"))
|
tarfiles = glob(os.path.join(archive_dir, "*.tgz"))
|
||||||
number_of_files = len(tarfiles)
|
number_of_files = len(tarfiles)
|
||||||
number_of_test = number_of_files//100
|
number_of_test = number_of_files//100
|
||||||
number_of_dev = number_of_files//100
|
number_of_dev = number_of_files//100
|
||||||
|
@ -156,20 +153,20 @@ def _download_and_preprocess_data(data_dir):
|
||||||
train_files = _generate_dataset(data_dir, "train")
|
train_files = _generate_dataset(data_dir, "train")
|
||||||
|
|
||||||
# Write sets to disk as CSV files
|
# Write sets to disk as CSV files
|
||||||
train_files.to_csv(path.join(data_dir, "voxforge-train.csv"), index=False)
|
train_files.to_csv(os.path.join(data_dir, "voxforge-train.csv"), index=False)
|
||||||
dev_files.to_csv(path.join(data_dir, "voxforge-dev.csv"), index=False)
|
dev_files.to_csv(os.path.join(data_dir, "voxforge-dev.csv"), index=False)
|
||||||
test_files.to_csv(path.join(data_dir, "voxforge-test.csv"), index=False)
|
test_files.to_csv(os.path.join(data_dir, "voxforge-test.csv"), index=False)
|
||||||
|
|
||||||
def _generate_dataset(data_dir, data_set):
|
def _generate_dataset(data_dir, data_set):
|
||||||
extracted_dir = path.join(data_dir, data_set)
|
extracted_dir = path.join(data_dir, data_set)
|
||||||
files = []
|
files = []
|
||||||
for promts_file in glob(path.join(extracted_dir+"/*/etc/", "PROMPTS")):
|
for promts_file in glob(os.path.join(extracted_dir+"/*/etc/", "PROMPTS")):
|
||||||
if path.isdir(path.join(promts_file[:-11],"wav")):
|
if path.isdir(os.path.join(promts_file[:-11], "wav")):
|
||||||
with codecs.open(promts_file, 'r', 'utf-8') as f:
|
with codecs.open(promts_file, 'r', 'utf-8') as f:
|
||||||
for line in f:
|
for line in f:
|
||||||
id = line.split(' ')[0].split('/')[-1]
|
id = line.split(' ')[0].split('/')[-1]
|
||||||
sentence = ' '.join(line.split(' ')[1:])
|
sentence = ' '.join(line.split(' ')[1:])
|
||||||
sentence = re.sub("[^a-z']"," ",sentence.strip().lower())
|
sentence = re.sub("[^a-z']", " ",sentence.strip().lower())
|
||||||
transcript = ""
|
transcript = ""
|
||||||
for token in sentence.split(" "):
|
for token in sentence.split(" "):
|
||||||
word = token.strip()
|
word = token.strip()
|
||||||
|
@ -178,14 +175,14 @@ def _generate_dataset(data_dir, data_set):
|
||||||
transcript = unicodedata.normalize("NFKD", transcript.strip()) \
|
transcript = unicodedata.normalize("NFKD", transcript.strip()) \
|
||||||
.encode("ascii", "ignore") \
|
.encode("ascii", "ignore") \
|
||||||
.decode("ascii", "ignore")
|
.decode("ascii", "ignore")
|
||||||
wav_file = path.join(promts_file[:-11],"wav/" + id + ".wav")
|
wav_file = path.join(promts_file[:-11], "wav/" + id + ".wav")
|
||||||
if gfile.Exists(wav_file):
|
if gfile.Exists(wav_file):
|
||||||
wav_filesize = path.getsize(wav_file)
|
wav_filesize = path.getsize(wav_file)
|
||||||
# remove audios that are shorter than 0.5s and longer than 20s.
|
# remove audios that are shorter than 0.5s and longer than 20s.
|
||||||
# remove audios that are too short for transcript.
|
# remove audios that are too short for transcript.
|
||||||
if (wav_filesize/32000)>0.5 and (wav_filesize/32000)<20 and transcript!="" and \
|
if ((wav_filesize/32000) > 0.5 and (wav_filesize/32000) < 20 and transcript != "" and
|
||||||
wav_filesize/len(transcript)>1400:
|
wav_filesize/len(transcript) > 1400):
|
||||||
files.append((path.abspath(wav_file), wav_filesize, transcript))
|
files.append((os.path.abspath(wav_file), wav_filesize, transcript))
|
||||||
|
|
||||||
return pandas.DataFrame(data=files, columns=["wav_filename", "wav_filesize", "transcript"])
|
return pandas.DataFrame(data=files, columns=["wav_filename", "wav_filesize", "transcript"])
|
||||||
|
|
||||||
|
|
13
bin/play.py
13
bin/play.py
|
@ -5,17 +5,12 @@ Use "python3 build_sdb.py -h" for help
|
||||||
"""
|
"""
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
# Make sure we can import stuff from util/
|
|
||||||
# This script needs to be run from the root of the DeepSpeech repository
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
|
||||||
|
|
||||||
import random
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import random
|
||||||
|
import sys
|
||||||
|
|
||||||
from util.sample_collections import samples_from_file, LabeledSample
|
from deepspeech_training.util.audio import AUDIO_TYPE_PCM
|
||||||
from util.audio import AUDIO_TYPE_PCM
|
from deepspeech_training.util.sample_collections import samples_from_file, LabeledSample
|
||||||
|
|
||||||
|
|
||||||
def play_sample(samples, index):
|
def play_sample(samples, index):
|
||||||
|
|
|
@ -1,17 +1,11 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
# Make sure we can import stuff from util/
|
|
||||||
# This script needs to be run from the root of the DeepSpeech repository
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
|
|
||||||
sys.path.insert(1, os.path.join(sys.path[0], "..", ".."))
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import shutil
|
import shutil
|
||||||
|
import sys
|
||||||
|
|
||||||
from util.text import Alphabet, UTF8Alphabet
|
from deepspeech_training.util.text import Alphabet, UTF8Alphabet
|
||||||
from ds_ctcdecoder import Scorer, Alphabet as NativeAlphabet
|
from ds_ctcdecoder import Scorer, Alphabet as NativeAlphabet
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -25,7 +25,7 @@ In creating a virtual environment you will create a directory containing a ``pyt
|
||||||
|
|
||||||
.. code-block::
|
.. code-block::
|
||||||
|
|
||||||
$ virtualenv -p python3 $HOME/tmp/deepspeech-train-venv/
|
$ python3 -m venv $HOME/tmp/deepspeech-train-venv/
|
||||||
|
|
||||||
Once this command completes successfully, the environment will be ready to be activated.
|
Once this command completes successfully, the environment will be ready to be activated.
|
||||||
|
|
||||||
|
@ -46,7 +46,7 @@ Install the required dependencies using ``pip3``\ :
|
||||||
.. code-block:: bash
|
.. code-block:: bash
|
||||||
|
|
||||||
cd DeepSpeech
|
cd DeepSpeech
|
||||||
pip3 install -r requirements.txt
|
pip3 install -e .
|
||||||
|
|
||||||
The ``webrtcvad`` Python package might require you to ensure you have proper tooling to build Python modules:
|
The ``webrtcvad`` Python package might require you to ensure you have proper tooling to build Python modules:
|
||||||
|
|
||||||
|
@ -70,7 +70,7 @@ If you have a capable (NVIDIA, at least 8GB of VRAM) GPU, it is highly recommend
|
||||||
.. code-block:: bash
|
.. code-block:: bash
|
||||||
|
|
||||||
pip3 uninstall tensorflow
|
pip3 uninstall tensorflow
|
||||||
pip3 install 'tensorflow-gpu==1.15.0'
|
pip3 install 'tensorflow-gpu==1.15.2'
|
||||||
|
|
||||||
Please ensure you have the required `CUDA dependency <USING.rst#cuda-dependency>`_.
|
Please ensure you have the required `CUDA dependency <USING.rst#cuda-dependency>`_.
|
||||||
|
|
||||||
|
|
|
@ -2,155 +2,11 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
import json
|
|
||||||
import sys
|
|
||||||
|
|
||||||
from multiprocessing import cpu_count
|
|
||||||
|
|
||||||
import absl.app
|
|
||||||
import progressbar
|
|
||||||
import tensorflow as tf
|
|
||||||
import tensorflow.compat.v1 as tfv1
|
|
||||||
|
|
||||||
from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
|
|
||||||
from six.moves import zip
|
|
||||||
|
|
||||||
from util.config import Config, initialize_globals
|
|
||||||
from util.checkpoints import load_or_init_graph
|
|
||||||
from util.evaluate_tools import calculate_and_print_report
|
|
||||||
from util.feeding import create_dataset
|
|
||||||
from util.flags import create_flags, FLAGS
|
|
||||||
from util.helpers import check_ctcdecoder_version
|
|
||||||
from util.logging import create_progressbar, log_error, log_progress
|
|
||||||
|
|
||||||
check_ctcdecoder_version()
|
|
||||||
|
|
||||||
def sparse_tensor_value_to_texts(value, alphabet):
|
|
||||||
r"""
|
|
||||||
Given a :class:`tf.SparseTensor` ``value``, return an array of Python strings
|
|
||||||
representing its values, converting tokens to strings using ``alphabet``.
|
|
||||||
"""
|
|
||||||
return sparse_tuple_to_texts((value.indices, value.values, value.dense_shape), alphabet)
|
|
||||||
|
|
||||||
|
|
||||||
def sparse_tuple_to_texts(sp_tuple, alphabet):
|
|
||||||
indices = sp_tuple[0]
|
|
||||||
values = sp_tuple[1]
|
|
||||||
results = [[] for _ in range(sp_tuple[2][0])]
|
|
||||||
for i, index in enumerate(indices):
|
|
||||||
results[index[0]].append(values[i])
|
|
||||||
# List of strings
|
|
||||||
return [alphabet.decode(res) for res in results]
|
|
||||||
|
|
||||||
|
|
||||||
def evaluate(test_csvs, create_model):
|
|
||||||
if FLAGS.scorer_path:
|
|
||||||
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
|
|
||||||
FLAGS.scorer_path, Config.alphabet)
|
|
||||||
else:
|
|
||||||
scorer = None
|
|
||||||
|
|
||||||
test_csvs = FLAGS.test_files.split(',')
|
|
||||||
test_sets = [create_dataset([csv], batch_size=FLAGS.test_batch_size, train_phase=False) for csv in test_csvs]
|
|
||||||
iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(test_sets[0]),
|
|
||||||
tfv1.data.get_output_shapes(test_sets[0]),
|
|
||||||
output_classes=tfv1.data.get_output_classes(test_sets[0]))
|
|
||||||
test_init_ops = [iterator.make_initializer(test_set) for test_set in test_sets]
|
|
||||||
|
|
||||||
batch_wav_filename, (batch_x, batch_x_len), batch_y = iterator.get_next()
|
|
||||||
|
|
||||||
# One rate per layer
|
|
||||||
no_dropout = [None] * 6
|
|
||||||
logits, _ = create_model(batch_x=batch_x,
|
|
||||||
batch_size=FLAGS.test_batch_size,
|
|
||||||
seq_length=batch_x_len,
|
|
||||||
dropout=no_dropout)
|
|
||||||
|
|
||||||
# Transpose to batch major and apply softmax for decoder
|
|
||||||
transposed = tf.nn.softmax(tf.transpose(a=logits, perm=[1, 0, 2]))
|
|
||||||
|
|
||||||
loss = tfv1.nn.ctc_loss(labels=batch_y,
|
|
||||||
inputs=logits,
|
|
||||||
sequence_length=batch_x_len)
|
|
||||||
|
|
||||||
tfv1.train.get_or_create_global_step()
|
|
||||||
|
|
||||||
# Get number of accessible CPU cores for this process
|
|
||||||
try:
|
|
||||||
num_processes = cpu_count()
|
|
||||||
except NotImplementedError:
|
|
||||||
num_processes = 1
|
|
||||||
|
|
||||||
with tfv1.Session(config=Config.session_config) as session:
|
|
||||||
if FLAGS.load == 'auto':
|
|
||||||
method_order = ['best', 'last']
|
|
||||||
else:
|
|
||||||
method_order = [FLAGS.load]
|
|
||||||
load_or_init_graph(session, method_order)
|
|
||||||
|
|
||||||
def run_test(init_op, dataset):
|
|
||||||
wav_filenames = []
|
|
||||||
losses = []
|
|
||||||
predictions = []
|
|
||||||
ground_truths = []
|
|
||||||
|
|
||||||
bar = create_progressbar(prefix='Test epoch | ',
|
|
||||||
widgets=['Steps: ', progressbar.Counter(), ' | ', progressbar.Timer()]).start()
|
|
||||||
log_progress('Test epoch...')
|
|
||||||
|
|
||||||
step_count = 0
|
|
||||||
|
|
||||||
# Initialize iterator to the appropriate dataset
|
|
||||||
session.run(init_op)
|
|
||||||
|
|
||||||
# First pass, compute losses and transposed logits for decoding
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
batch_wav_filenames, batch_logits, batch_loss, batch_lengths, batch_transcripts = \
|
|
||||||
session.run([batch_wav_filename, transposed, loss, batch_x_len, batch_y])
|
|
||||||
except tf.errors.OutOfRangeError:
|
|
||||||
break
|
|
||||||
|
|
||||||
decoded = ctc_beam_search_decoder_batch(batch_logits, batch_lengths, Config.alphabet, FLAGS.beam_width,
|
|
||||||
num_processes=num_processes, scorer=scorer,
|
|
||||||
cutoff_prob=FLAGS.cutoff_prob, cutoff_top_n=FLAGS.cutoff_top_n)
|
|
||||||
predictions.extend(d[0][1] for d in decoded)
|
|
||||||
ground_truths.extend(sparse_tensor_value_to_texts(batch_transcripts, Config.alphabet))
|
|
||||||
wav_filenames.extend(wav_filename.decode('UTF-8') for wav_filename in batch_wav_filenames)
|
|
||||||
losses.extend(batch_loss)
|
|
||||||
|
|
||||||
step_count += 1
|
|
||||||
bar.update(step_count)
|
|
||||||
|
|
||||||
bar.finish()
|
|
||||||
|
|
||||||
# Print test summary
|
|
||||||
test_samples = calculate_and_print_report(wav_filenames, ground_truths, predictions, losses, dataset)
|
|
||||||
return test_samples
|
|
||||||
|
|
||||||
samples = []
|
|
||||||
for csv, init_op in zip(test_csvs, test_init_ops):
|
|
||||||
print('Testing model on {}'.format(csv))
|
|
||||||
samples.extend(run_test(init_op, dataset=csv))
|
|
||||||
return samples
|
|
||||||
|
|
||||||
|
|
||||||
def main(_):
|
|
||||||
initialize_globals()
|
|
||||||
|
|
||||||
if not FLAGS.test_files:
|
|
||||||
log_error('You need to specify what files to use for evaluation via '
|
|
||||||
'the --test_files flag.')
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
from DeepSpeech import create_model # pylint: disable=cyclic-import,import-outside-toplevel
|
|
||||||
samples = evaluate(FLAGS.test_files.split(','), create_model)
|
|
||||||
|
|
||||||
if FLAGS.test_output_file:
|
|
||||||
# Save decoded tuples as JSON, converting NumPy floats to Python floats
|
|
||||||
json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
create_flags()
|
try:
|
||||||
absl.app.run(main)
|
from deepspeech_training import evaluate as ds_evaluate
|
||||||
|
except ImportError:
|
||||||
|
print('Training package is not installed. See training documentation.')
|
||||||
|
raise
|
||||||
|
|
||||||
|
ds_evaluate.run_script()
|
||||||
|
|
|
@ -10,13 +10,12 @@ import csv
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from functools import partial
|
|
||||||
from six.moves import zip, range
|
|
||||||
from multiprocessing import JoinableQueue, Process, cpu_count, Manager
|
|
||||||
from deepspeech import Model
|
from deepspeech import Model
|
||||||
|
from deepspeech_training.util.evaluate_tools import calculate_and_print_report
|
||||||
from util.evaluate_tools import calculate_and_print_report
|
from deepspeech_training.util.flags import create_flags
|
||||||
from util.flags import create_flags
|
from functools import partial
|
||||||
|
from multiprocessing import JoinableQueue, Process, cpu_count, Manager
|
||||||
|
from six.moves import zip, range
|
||||||
|
|
||||||
r'''
|
r'''
|
||||||
This module should be self-contained:
|
This module should be self-contained:
|
||||||
|
|
|
@ -2,19 +2,18 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
from __future__ import absolute_import, print_function
|
from __future__ import absolute_import, print_function
|
||||||
|
|
||||||
import sys
|
|
||||||
|
|
||||||
import optuna
|
|
||||||
import absl.app
|
import absl.app
|
||||||
from ds_ctcdecoder import Scorer
|
import optuna
|
||||||
|
import sys
|
||||||
import tensorflow.compat.v1 as tfv1
|
import tensorflow.compat.v1 as tfv1
|
||||||
|
|
||||||
from DeepSpeech import create_model
|
from deepspeech_training.evaluate import evaluate
|
||||||
from evaluate import evaluate
|
from deepspeech_training.train import create_model
|
||||||
from util.config import Config, initialize_globals
|
from deepspeech_training.util.config import Config, initialize_globals
|
||||||
from util.flags import create_flags, FLAGS
|
from deepspeech_training.util.flags import create_flags, FLAGS
|
||||||
from util.logging import log_error
|
from deepspeech_training.util.logging import log_error
|
||||||
from util.evaluate_tools import wer_cer_batch
|
from deepspeech_training.util.evaluate_tools import wer_cer_batch
|
||||||
|
from ds_ctcdecoder import Scorer
|
||||||
|
|
||||||
|
|
||||||
def character_based():
|
def character_based():
|
||||||
|
|
|
@ -1,24 +0,0 @@
|
||||||
# Main training requirements
|
|
||||||
tensorflow == 1.15.2
|
|
||||||
numpy == 1.18.1
|
|
||||||
progressbar2
|
|
||||||
six
|
|
||||||
pyxdg
|
|
||||||
attrdict
|
|
||||||
absl-py
|
|
||||||
semver
|
|
||||||
opuslib == 2.0.0
|
|
||||||
|
|
||||||
# Requirements for building native_client files
|
|
||||||
setuptools
|
|
||||||
|
|
||||||
# Requirements for importers
|
|
||||||
sox
|
|
||||||
bs4
|
|
||||||
pandas
|
|
||||||
requests
|
|
||||||
librosa
|
|
||||||
soundfile
|
|
||||||
|
|
||||||
# Requirements for optimizer
|
|
||||||
optuna
|
|
|
@ -0,0 +1,53 @@
|
||||||
|
from setuptools import setup, find_packages
|
||||||
|
|
||||||
|
def main():
|
||||||
|
setup(
|
||||||
|
name='deepspeech_training',
|
||||||
|
version='0.0.1',
|
||||||
|
description='Training code for mozilla DeepSpeech',
|
||||||
|
url='https://github.com/mozilla/DeepSpeech',
|
||||||
|
author='Mozilla',
|
||||||
|
license='MPL-2.0',
|
||||||
|
# Classifiers help users find your project by categorizing it.
|
||||||
|
#
|
||||||
|
# For a list of valid classifiers, see https://pypi.org/classifiers/
|
||||||
|
classifiers=[
|
||||||
|
'Development Status :: 3 - Alpha',
|
||||||
|
'Intended Audience :: Developers',
|
||||||
|
'Topic :: Multimedia :: Sound/Audio :: Speech',
|
||||||
|
'License :: OSI Approved :: Mozilla Public License 2.0 (MPL 2.0)',
|
||||||
|
'Programming Language :: Python :: 3',
|
||||||
|
],
|
||||||
|
package_dir={'': 'training'},
|
||||||
|
packages=find_packages(where='training'),
|
||||||
|
python_requires='>=3.5, <4',
|
||||||
|
install_requires=[
|
||||||
|
'tensorflow == 1.15.2',
|
||||||
|
'numpy == 1.18.1',
|
||||||
|
'progressbar2',
|
||||||
|
'six',
|
||||||
|
'pyxdg',
|
||||||
|
'attrdict',
|
||||||
|
'absl-py',
|
||||||
|
'semver',
|
||||||
|
'opuslib == 2.0.0',
|
||||||
|
'optuna',
|
||||||
|
'sox',
|
||||||
|
'bs4',
|
||||||
|
'pandas',
|
||||||
|
'requests',
|
||||||
|
'librosa',
|
||||||
|
'soundfile',
|
||||||
|
],
|
||||||
|
# If there are data files included in your packages that need to be
|
||||||
|
# installed, specify them here.
|
||||||
|
package_data={
|
||||||
|
'deepspeech_training': [
|
||||||
|
'VERSION',
|
||||||
|
'GRAPH_VERSION',
|
||||||
|
],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
43
stats.py
43
stats.py
|
@ -1,10 +1,29 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import functools
|
||||||
|
import pandas
|
||||||
|
|
||||||
|
from deepspeech_training.util.helpers import secs_to_hours
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
def read_csvs(csv_files):
|
||||||
|
# Relative paths are relative to CSV location
|
||||||
|
def absolutify(csv, path):
|
||||||
|
path = Path(path)
|
||||||
|
if path.is_absolute():
|
||||||
|
return str(path)
|
||||||
|
return str(csv.parent / path)
|
||||||
|
|
||||||
|
sets = []
|
||||||
|
for csv in csv_files:
|
||||||
|
file = pandas.read_csv(csv, encoding='utf-8', na_filter=False)
|
||||||
|
file['wav_filename'] = file['wav_filename'].apply(functools.partial(absolutify, csv))
|
||||||
|
sets.append(file)
|
||||||
|
|
||||||
|
# Concat all sets, drop any extra columns, re-index the final result as 0..N
|
||||||
|
return pandas.concat(sets, join='inner', ignore_index=True)
|
||||||
|
|
||||||
from util.helpers import secs_to_hours
|
|
||||||
from util.feeding import read_csvs
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
@ -14,20 +33,16 @@ def main():
|
||||||
parser.add_argument("--channels", type=int, default=1, required=False, help="Audio channels")
|
parser.add_argument("--channels", type=int, default=1, required=False, help="Audio channels")
|
||||||
parser.add_argument("--bits-per-sample", type=int, default=16, required=False, help="Audio bits per sample")
|
parser.add_argument("--bits-per-sample", type=int, default=16, required=False, help="Audio bits per sample")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
in_files = [os.path.abspath(i) for i in args.csv_files.split(",")]
|
in_files = [Path(i).absolute() for i in args.csv_files.split(",")]
|
||||||
|
|
||||||
csv_dataframe = read_csvs(in_files)
|
csv_dataframe = read_csvs(in_files)
|
||||||
total_bytes = csv_dataframe['wav_filesize'].sum()
|
total_bytes = csv_dataframe['wav_filesize'].sum()
|
||||||
total_files = len(csv_dataframe.index)
|
total_files = len(csv_dataframe)
|
||||||
|
total_seconds = ((csv_dataframe['wav_filesize'] - 44) / args.sample_rate / args.channels / (args.bits_per_sample // 8)).sum()
|
||||||
|
|
||||||
bytes_without_headers = total_bytes - 44 * total_files
|
print('Total bytes:', total_bytes)
|
||||||
|
print('Total files:', total_files)
|
||||||
total_time = bytes_without_headers / (args.sample_rate * args.channels * args.bits_per_sample / 8)
|
print('Total time:', secs_to_hours(total_seconds))
|
||||||
|
|
||||||
print('total_bytes', total_bytes)
|
|
||||||
print('total_files', total_files)
|
|
||||||
print('bytes_without_headers', bytes_without_headers)
|
|
||||||
print('total_time', secs_to_hours(total_time))
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
||||||
|
|
|
@ -17,7 +17,9 @@ deepspeech_pkg_url=$(get_python_pkg_url ${pyver_pkg} ${py_unicode_type})
|
||||||
set -o pipefail
|
set -o pipefail
|
||||||
LD_LIBRARY_PATH=${PY37_LDPATH}:$LD_LIBRARY_PATH pip install --verbose --only-binary :all: --upgrade ${deepspeech_pkg_url} | cat
|
LD_LIBRARY_PATH=${PY37_LDPATH}:$LD_LIBRARY_PATH pip install --verbose --only-binary :all: --upgrade ${deepspeech_pkg_url} | cat
|
||||||
pip install --upgrade pip==19.3.1 setuptools==45.0.0 wheel==0.33.6 | cat
|
pip install --upgrade pip==19.3.1 setuptools==45.0.0 wheel==0.33.6 | cat
|
||||||
pip install --upgrade -r ${HOME}/DeepSpeech/ds/requirements.txt | cat
|
pushd ${HOME}/DeepSpeech/ds
|
||||||
|
pip install --upgrade . | cat
|
||||||
|
popd
|
||||||
set +o pipefail
|
set +o pipefail
|
||||||
|
|
||||||
which deepspeech
|
which deepspeech
|
||||||
|
|
|
@ -17,7 +17,9 @@ virtualenv_activate "${pyalias}" "deepspeech"
|
||||||
|
|
||||||
set -o pipefail
|
set -o pipefail
|
||||||
pip install --upgrade pip==19.3.1 setuptools==45.0.0 wheel==0.33.6 | cat
|
pip install --upgrade pip==19.3.1 setuptools==45.0.0 wheel==0.33.6 | cat
|
||||||
pip install --upgrade -r ${HOME}/DeepSpeech/ds/requirements.txt | cat
|
pushd ${HOME}/DeepSpeech/ds
|
||||||
|
pip install --upgrade . | cat
|
||||||
|
popd
|
||||||
set +o pipefail
|
set +o pipefail
|
||||||
|
|
||||||
decoder_pkg_url=$(get_python_pkg_url ${pyver_pkg} ${py_unicode_type} "ds_ctcdecoder" "${DECODER_ARTIFACTS_ROOT}")
|
decoder_pkg_url=$(get_python_pkg_url ${pyver_pkg} ${py_unicode_type} "ds_ctcdecoder" "${DECODER_ARTIFACTS_ROOT}")
|
||||||
|
|
|
@ -16,7 +16,9 @@ virtualenv_activate "${pyalias}" "deepspeech"
|
||||||
|
|
||||||
set -o pipefail
|
set -o pipefail
|
||||||
pip install --upgrade pip==19.3.1 setuptools==45.0.0 wheel==0.33.6 | cat
|
pip install --upgrade pip==19.3.1 setuptools==45.0.0 wheel==0.33.6 | cat
|
||||||
pip install --upgrade -r ${HOME}/DeepSpeech/ds/requirements.txt | cat
|
pushd ${HOME}/DeepSpeech/ds
|
||||||
|
pip install --upgrade . | cat
|
||||||
|
popd
|
||||||
set +o pipefail
|
set +o pipefail
|
||||||
|
|
||||||
pushd ${HOME}/DeepSpeech/ds/
|
pushd ${HOME}/DeepSpeech/ds/
|
||||||
|
|
|
@ -14,7 +14,9 @@ virtualenv_activate "${pyalias}" "deepspeech"
|
||||||
|
|
||||||
set -o pipefail
|
set -o pipefail
|
||||||
pip install --upgrade pip==19.3.1 setuptools==45.0.0 wheel==0.33.6 | cat
|
pip install --upgrade pip==19.3.1 setuptools==45.0.0 wheel==0.33.6 | cat
|
||||||
pip install --upgrade -r ${HOME}/DeepSpeech/ds/requirements.txt | cat
|
pushd ${HOME}/DeepSpeech/ds
|
||||||
|
pip install --upgrade . | cat
|
||||||
|
popd
|
||||||
set +o pipefail
|
set +o pipefail
|
||||||
|
|
||||||
pushd ${HOME}/DeepSpeech/ds/
|
pushd ${HOME}/DeepSpeech/ds/
|
||||||
|
|
|
@ -1,10 +1,14 @@
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
from .importers import validate_label_eng, get_validate_label
|
from deepspeech_training.util.importers import validate_label_eng, get_validate_label
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
def from_here(path):
|
||||||
|
here = Path(__file__)
|
||||||
|
return here.parent / path
|
||||||
|
|
||||||
class TestValidateLabelEng(unittest.TestCase):
|
class TestValidateLabelEng(unittest.TestCase):
|
||||||
|
|
||||||
def test_numbers(self):
|
def test_numbers(self):
|
||||||
label = validate_label_eng("this is a 1 2 3 test")
|
label = validate_label_eng("this is a 1 2 3 test")
|
||||||
self.assertEqual(label, None)
|
self.assertEqual(label, None)
|
||||||
|
@ -24,12 +28,12 @@ class TestGetValidateLabel(unittest.TestCase):
|
||||||
self.assertEqual(f('toto1234[{[{[]'), None)
|
self.assertEqual(f('toto1234[{[{[]'), None)
|
||||||
|
|
||||||
def test_get_validate_label_missing(self):
|
def test_get_validate_label_missing(self):
|
||||||
args = Namespace(validate_label_locale='util/test_data/validate_locale_ger.py')
|
args = Namespace(validate_label_locale=from_here('test_data/validate_locale_ger.py'))
|
||||||
f = get_validate_label(args)
|
f = get_validate_label(args)
|
||||||
self.assertEqual(f, None)
|
self.assertEqual(f, None)
|
||||||
|
|
||||||
def test_get_validate_label(self):
|
def test_get_validate_label(self):
|
||||||
args = Namespace(validate_label_locale='util/test_data/validate_locale_fra.py')
|
args = Namespace(validate_label_locale=from_here('test_data/validate_locale_fra.py'))
|
||||||
f = get_validate_label(args)
|
f = get_validate_label(args)
|
||||||
l = f('toto')
|
l = f('toto')
|
||||||
self.assertEqual(l, 'toto')
|
self.assertEqual(l, 'toto')
|
|
@ -1,7 +1,7 @@
|
||||||
import unittest
|
import unittest
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from .text import Alphabet
|
from deepspeech_training.util.text import Alphabet
|
||||||
|
|
||||||
class TestAlphabetParsing(unittest.TestCase):
|
class TestAlphabetParsing(unittest.TestCase):
|
||||||
|
|
|
@ -0,0 +1 @@
|
||||||
|
../../GRAPH_VERSION
|
|
@ -0,0 +1 @@
|
||||||
|
../../VERSION
|
|
@ -0,0 +1,159 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from multiprocessing import cpu_count
|
||||||
|
|
||||||
|
import absl.app
|
||||||
|
import progressbar
|
||||||
|
import tensorflow as tf
|
||||||
|
import tensorflow.compat.v1 as tfv1
|
||||||
|
|
||||||
|
from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
|
||||||
|
from six.moves import zip
|
||||||
|
|
||||||
|
from .util.config import Config, initialize_globals
|
||||||
|
from .util.checkpoints import load_or_init_graph
|
||||||
|
from .util.evaluate_tools import calculate_and_print_report
|
||||||
|
from .util.feeding import create_dataset
|
||||||
|
from .util.flags import create_flags, FLAGS
|
||||||
|
from .util.helpers import check_ctcdecoder_version
|
||||||
|
from .util.logging import create_progressbar, log_error, log_progress
|
||||||
|
|
||||||
|
check_ctcdecoder_version()
|
||||||
|
|
||||||
|
def sparse_tensor_value_to_texts(value, alphabet):
|
||||||
|
r"""
|
||||||
|
Given a :class:`tf.SparseTensor` ``value``, return an array of Python strings
|
||||||
|
representing its values, converting tokens to strings using ``alphabet``.
|
||||||
|
"""
|
||||||
|
return sparse_tuple_to_texts((value.indices, value.values, value.dense_shape), alphabet)
|
||||||
|
|
||||||
|
|
||||||
|
def sparse_tuple_to_texts(sp_tuple, alphabet):
|
||||||
|
indices = sp_tuple[0]
|
||||||
|
values = sp_tuple[1]
|
||||||
|
results = [[] for _ in range(sp_tuple[2][0])]
|
||||||
|
for i, index in enumerate(indices):
|
||||||
|
results[index[0]].append(values[i])
|
||||||
|
# List of strings
|
||||||
|
return [alphabet.decode(res) for res in results]
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate(test_csvs, create_model):
|
||||||
|
if FLAGS.scorer_path:
|
||||||
|
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
|
||||||
|
FLAGS.scorer_path, Config.alphabet)
|
||||||
|
else:
|
||||||
|
scorer = None
|
||||||
|
|
||||||
|
test_csvs = FLAGS.test_files.split(',')
|
||||||
|
test_sets = [create_dataset([csv], batch_size=FLAGS.test_batch_size, train_phase=False) for csv in test_csvs]
|
||||||
|
iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(test_sets[0]),
|
||||||
|
tfv1.data.get_output_shapes(test_sets[0]),
|
||||||
|
output_classes=tfv1.data.get_output_classes(test_sets[0]))
|
||||||
|
test_init_ops = [iterator.make_initializer(test_set) for test_set in test_sets]
|
||||||
|
|
||||||
|
batch_wav_filename, (batch_x, batch_x_len), batch_y = iterator.get_next()
|
||||||
|
|
||||||
|
# One rate per layer
|
||||||
|
no_dropout = [None] * 6
|
||||||
|
logits, _ = create_model(batch_x=batch_x,
|
||||||
|
batch_size=FLAGS.test_batch_size,
|
||||||
|
seq_length=batch_x_len,
|
||||||
|
dropout=no_dropout)
|
||||||
|
|
||||||
|
# Transpose to batch major and apply softmax for decoder
|
||||||
|
transposed = tf.nn.softmax(tf.transpose(a=logits, perm=[1, 0, 2]))
|
||||||
|
|
||||||
|
loss = tfv1.nn.ctc_loss(labels=batch_y,
|
||||||
|
inputs=logits,
|
||||||
|
sequence_length=batch_x_len)
|
||||||
|
|
||||||
|
tfv1.train.get_or_create_global_step()
|
||||||
|
|
||||||
|
# Get number of accessible CPU cores for this process
|
||||||
|
try:
|
||||||
|
num_processes = cpu_count()
|
||||||
|
except NotImplementedError:
|
||||||
|
num_processes = 1
|
||||||
|
|
||||||
|
with tfv1.Session(config=Config.session_config) as session:
|
||||||
|
if FLAGS.load == 'auto':
|
||||||
|
method_order = ['best', 'last']
|
||||||
|
else:
|
||||||
|
method_order = [FLAGS.load]
|
||||||
|
load_or_init_graph(session, method_order)
|
||||||
|
|
||||||
|
def run_test(init_op, dataset):
|
||||||
|
wav_filenames = []
|
||||||
|
losses = []
|
||||||
|
predictions = []
|
||||||
|
ground_truths = []
|
||||||
|
|
||||||
|
bar = create_progressbar(prefix='Test epoch | ',
|
||||||
|
widgets=['Steps: ', progressbar.Counter(), ' | ', progressbar.Timer()]).start()
|
||||||
|
log_progress('Test epoch...')
|
||||||
|
|
||||||
|
step_count = 0
|
||||||
|
|
||||||
|
# Initialize iterator to the appropriate dataset
|
||||||
|
session.run(init_op)
|
||||||
|
|
||||||
|
# First pass, compute losses and transposed logits for decoding
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
batch_wav_filenames, batch_logits, batch_loss, batch_lengths, batch_transcripts = \
|
||||||
|
session.run([batch_wav_filename, transposed, loss, batch_x_len, batch_y])
|
||||||
|
except tf.errors.OutOfRangeError:
|
||||||
|
break
|
||||||
|
|
||||||
|
decoded = ctc_beam_search_decoder_batch(batch_logits, batch_lengths, Config.alphabet, FLAGS.beam_width,
|
||||||
|
num_processes=num_processes, scorer=scorer,
|
||||||
|
cutoff_prob=FLAGS.cutoff_prob, cutoff_top_n=FLAGS.cutoff_top_n)
|
||||||
|
predictions.extend(d[0][1] for d in decoded)
|
||||||
|
ground_truths.extend(sparse_tensor_value_to_texts(batch_transcripts, Config.alphabet))
|
||||||
|
wav_filenames.extend(wav_filename.decode('UTF-8') for wav_filename in batch_wav_filenames)
|
||||||
|
losses.extend(batch_loss)
|
||||||
|
|
||||||
|
step_count += 1
|
||||||
|
bar.update(step_count)
|
||||||
|
|
||||||
|
bar.finish()
|
||||||
|
|
||||||
|
# Print test summary
|
||||||
|
test_samples = calculate_and_print_report(wav_filenames, ground_truths, predictions, losses, dataset)
|
||||||
|
return test_samples
|
||||||
|
|
||||||
|
samples = []
|
||||||
|
for csv, init_op in zip(test_csvs, test_init_ops):
|
||||||
|
print('Testing model on {}'.format(csv))
|
||||||
|
samples.extend(run_test(init_op, dataset=csv))
|
||||||
|
return samples
|
||||||
|
|
||||||
|
|
||||||
|
def main(_):
|
||||||
|
initialize_globals()
|
||||||
|
|
||||||
|
if not FLAGS.test_files:
|
||||||
|
log_error('You need to specify what files to use for evaluation via '
|
||||||
|
'the --test_files flag.')
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
from .train import create_model # pylint: disable=cyclic-import,import-outside-toplevel
|
||||||
|
samples = evaluate(FLAGS.test_files.split(','), create_model)
|
||||||
|
|
||||||
|
if FLAGS.test_output_file:
|
||||||
|
# Save decoded tuples as JSON, converting NumPy floats to Python floats
|
||||||
|
json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float)
|
||||||
|
|
||||||
|
|
||||||
|
def run_script():
|
||||||
|
create_flags()
|
||||||
|
absl.app.run(main)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
run_script()
|
|
@ -0,0 +1,936 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
LOG_LEVEL_INDEX = sys.argv.index('--log_level') + 1 if '--log_level' in sys.argv else 0
|
||||||
|
DESIRED_LOG_LEVEL = sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.argv) else '3'
|
||||||
|
os.environ['TF_CPP_MIN_LOG_LEVEL'] = DESIRED_LOG_LEVEL
|
||||||
|
|
||||||
|
import absl.app
|
||||||
|
import json
|
||||||
|
import numpy as np
|
||||||
|
import progressbar
|
||||||
|
import shutil
|
||||||
|
import tensorflow as tf
|
||||||
|
import tensorflow.compat.v1 as tfv1
|
||||||
|
import time
|
||||||
|
|
||||||
|
tfv1.logging.set_verbosity({
|
||||||
|
'0': tfv1.logging.DEBUG,
|
||||||
|
'1': tfv1.logging.INFO,
|
||||||
|
'2': tfv1.logging.WARN,
|
||||||
|
'3': tfv1.logging.ERROR
|
||||||
|
}.get(DESIRED_LOG_LEVEL))
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from ds_ctcdecoder import ctc_beam_search_decoder, Scorer
|
||||||
|
from .evaluate import evaluate
|
||||||
|
from six.moves import zip, range
|
||||||
|
from .util.config import Config, initialize_globals
|
||||||
|
from .util.checkpoints import load_or_init_graph
|
||||||
|
from .util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features
|
||||||
|
from .util.flags import create_flags, FLAGS
|
||||||
|
from .util.helpers import check_ctcdecoder_version, ExceptionBox
|
||||||
|
from .util.logging import log_info, log_error, log_debug, log_progress, create_progressbar
|
||||||
|
|
||||||
|
check_ctcdecoder_version()
|
||||||
|
|
||||||
|
# Graph Creation
|
||||||
|
# ==============
|
||||||
|
|
||||||
|
def variable_on_cpu(name, shape, initializer):
|
||||||
|
r"""
|
||||||
|
Next we concern ourselves with graph creation.
|
||||||
|
However, before we do so we must introduce a utility function ``variable_on_cpu()``
|
||||||
|
used to create a variable in CPU memory.
|
||||||
|
"""
|
||||||
|
# Use the /cpu:0 device for scoped operations
|
||||||
|
with tf.device(Config.cpu_device):
|
||||||
|
# Create or get apropos variable
|
||||||
|
var = tfv1.get_variable(name=name, shape=shape, initializer=initializer)
|
||||||
|
return var
|
||||||
|
|
||||||
|
|
||||||
|
def create_overlapping_windows(batch_x):
|
||||||
|
batch_size = tf.shape(input=batch_x)[0]
|
||||||
|
window_width = 2 * Config.n_context + 1
|
||||||
|
num_channels = Config.n_input
|
||||||
|
|
||||||
|
# Create a constant convolution filter using an identity matrix, so that the
|
||||||
|
# convolution returns patches of the input tensor as is, and we can create
|
||||||
|
# overlapping windows over the MFCCs.
|
||||||
|
eye_filter = tf.constant(np.eye(window_width * num_channels)
|
||||||
|
.reshape(window_width, num_channels, window_width * num_channels), tf.float32) # pylint: disable=bad-continuation
|
||||||
|
|
||||||
|
# Create overlapping windows
|
||||||
|
batch_x = tf.nn.conv1d(input=batch_x, filters=eye_filter, stride=1, padding='SAME')
|
||||||
|
|
||||||
|
# Remove dummy depth dimension and reshape into [batch_size, n_windows, window_width, n_input]
|
||||||
|
batch_x = tf.reshape(batch_x, [batch_size, -1, window_width, num_channels])
|
||||||
|
|
||||||
|
return batch_x
|
||||||
|
|
||||||
|
|
||||||
|
def dense(name, x, units, dropout_rate=None, relu=True):
|
||||||
|
with tfv1.variable_scope(name):
|
||||||
|
bias = variable_on_cpu('bias', [units], tfv1.zeros_initializer())
|
||||||
|
weights = variable_on_cpu('weights', [x.shape[-1], units], tfv1.keras.initializers.VarianceScaling(scale=1.0, mode="fan_avg", distribution="uniform"))
|
||||||
|
|
||||||
|
output = tf.nn.bias_add(tf.matmul(x, weights), bias)
|
||||||
|
|
||||||
|
if relu:
|
||||||
|
output = tf.minimum(tf.nn.relu(output), FLAGS.relu_clip)
|
||||||
|
|
||||||
|
if dropout_rate is not None:
|
||||||
|
output = tf.nn.dropout(output, rate=dropout_rate)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def rnn_impl_lstmblockfusedcell(x, seq_length, previous_state, reuse):
|
||||||
|
with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell/cell_0'):
|
||||||
|
fw_cell = tf.contrib.rnn.LSTMBlockFusedCell(Config.n_cell_dim,
|
||||||
|
forget_bias=0,
|
||||||
|
reuse=reuse,
|
||||||
|
name='cudnn_compatible_lstm_cell')
|
||||||
|
|
||||||
|
output, output_state = fw_cell(inputs=x,
|
||||||
|
dtype=tf.float32,
|
||||||
|
sequence_length=seq_length,
|
||||||
|
initial_state=previous_state)
|
||||||
|
|
||||||
|
return output, output_state
|
||||||
|
|
||||||
|
|
||||||
|
def rnn_impl_cudnn_rnn(x, seq_length, previous_state, _):
|
||||||
|
assert previous_state is None # 'Passing previous state not supported with CuDNN backend'
|
||||||
|
|
||||||
|
# Hack: CudnnLSTM works similarly to Keras layers in that when you instantiate
|
||||||
|
# the object it creates the variables, and then you just call it several times
|
||||||
|
# to enable variable re-use. Because all of our code is structure in an old
|
||||||
|
# school TensorFlow structure where you can just call tf.get_variable again with
|
||||||
|
# reuse=True to reuse variables, we can't easily make use of the object oriented
|
||||||
|
# way CudnnLSTM is implemented, so we save a singleton instance in the function,
|
||||||
|
# emulating a static function variable.
|
||||||
|
if not rnn_impl_cudnn_rnn.cell:
|
||||||
|
# Forward direction cell:
|
||||||
|
fw_cell = tf.contrib.cudnn_rnn.CudnnLSTM(num_layers=1,
|
||||||
|
num_units=Config.n_cell_dim,
|
||||||
|
input_mode='linear_input',
|
||||||
|
direction='unidirectional',
|
||||||
|
dtype=tf.float32)
|
||||||
|
rnn_impl_cudnn_rnn.cell = fw_cell
|
||||||
|
|
||||||
|
output, output_state = rnn_impl_cudnn_rnn.cell(inputs=x,
|
||||||
|
sequence_lengths=seq_length)
|
||||||
|
|
||||||
|
return output, output_state
|
||||||
|
|
||||||
|
rnn_impl_cudnn_rnn.cell = None
|
||||||
|
|
||||||
|
|
||||||
|
def rnn_impl_static_rnn(x, seq_length, previous_state, reuse):
    r'''
    Recurrent layer implementation backed by `tfv1.nn.static_rnn`.

    An `LSTMCell` named `cudnn_compatible_lstm_cell` is created inside the
    `cudnn_lstm/rnn/multi_rnn_cell` variable scope and unrolled over the slices
    of `x` along its first axis, using `cell_0` as the RNN scope. The per-step
    outputs are concatenated along axis 0 before being returned.

    Returns the `(output, output_state)` pair.
    '''
    with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell'):
        # Forward direction cell
        lstm_cell = tfv1.nn.rnn_cell.LSTMCell(Config.n_cell_dim,
                                              forget_bias=0,
                                              reuse=reuse,
                                              name='cudnn_compatible_lstm_cell')

        # static_rnn wants a Python list with one tensor per step, so slice the
        # leading axis of the rank N input into rank N-1 tensors.
        n_slices = x.shape[0]
        step_inputs = [x[step] for step in range(n_slices)]

        step_outputs, final_state = tfv1.nn.static_rnn(cell=lstm_cell,
                                                       inputs=step_inputs,
                                                       sequence_length=seq_length,
                                                       initial_state=previous_state,
                                                       dtype=tf.float32,
                                                       scope='cell_0')

        joined_output = tf.concat(step_outputs, 0)

    return joined_output, final_state
|
||||||
|
|
||||||
|
|
||||||
|
def create_model(batch_x, seq_length, dropout, reuse=False, batch_size=None, previous_state=None, overlap=True, rnn_impl=rnn_impl_lstmblockfusedcell):
    r'''
    Build the acoustic model graph: three dense layers, one recurrent layer,
    one more dense layer and a final linear projection to the logits.

    `dropout` is indexed per layer: entries 0, 1 and 2 are used for the first
    three dense layers and entry 5 for `layer_5`. `rnn_impl` is called as
    `rnn_impl(layer_3, seq_length, previous_state, reuse)` and must return an
    `(output, output_state)` pair. When `batch_size` is falsy it is read from
    the first dimension of `batch_x` at graph run time. When `overlap` is true
    the input is first passed through `create_overlapping_windows`.

    Returns `(raw_logits, layers)` where `raw_logits` has shape
    `[n_steps, batch_size, n_hidden_6]` (time-major) and `layers` maps layer
    names to the intermediate tensors.
    '''
    layers = {}

    # Input shape: [batch_size, n_steps, n_input + 2*n_input*n_context]
    if not batch_size:
        batch_size = tf.shape(input=batch_x)[0]

    # Create overlapping feature windows if needed
    if overlap:
        batch_x = create_overlapping_windows(batch_x)

    # The first dense layer expects a rank 2 tensor, so swap the batch and time
    # axes and then flatten everything but the feature window:
    # (n_steps*batch_size, n_input + 2*n_input*n_context)
    time_major_x = tf.transpose(a=batch_x, perm=[1, 0, 2, 3])
    flat_x = tf.reshape(time_major_x, [-1, Config.n_input + 2*Config.n_input*Config.n_context])
    layers['input_reshaped'] = flat_x

    # Three hidden layers with clipped RELU activation and dropout
    layer_1 = dense('layer_1', flat_x, Config.n_hidden_1, dropout_rate=dropout[0])
    layers['layer_1'] = layer_1

    layer_2 = dense('layer_2', layer_1, Config.n_hidden_2, dropout_rate=dropout[1])
    layers['layer_2'] = layer_2

    layer_3 = dense('layer_3', layer_2, Config.n_hidden_3, dropout_rate=dropout[2])
    layers['layer_3'] = layer_3

    # The recurrent layer consumes [max_time, batch_size, input_size], so fold
    # the flattened activations back into [n_steps, batch_size, n_hidden_3].
    rnn_input = tf.reshape(layer_3, [-1, batch_size, Config.n_hidden_3])

    # The RNN implementation is a parameter, as different RNNs are used
    # for training and inference
    rnn_output, rnn_state = rnn_impl(rnn_input, seq_length, previous_state, reuse)

    # Flatten [n_steps, batch_size, n_cell_dim] to [n_steps*batch_size, n_cell_dim]
    rnn_output = tf.reshape(rnn_output, [-1, Config.n_cell_dim])
    layers['rnn_output'] = rnn_output
    layers['rnn_output_state'] = rnn_state

    # Fifth hidden layer with clipped RELU activation
    layer_5 = dense('layer_5', rnn_output, Config.n_hidden_5, dropout_rate=dropout[5])
    layers['layer_5'] = layer_5

    # Final linear layer producing the `n_hidden_6` dimensional logits
    layer_6 = dense('layer_6', layer_5, Config.n_hidden_6, relu=False)
    layers['layer_6'] = layer_6

    # Fold [n_steps*batch_size, n_hidden_6] back into the more useful
    # [n_steps, batch_size, n_hidden_6]. Unlike the input, this is time-major.
    raw_logits = tf.reshape(layer_6, [-1, batch_size, Config.n_hidden_6], name='raw_logits')
    layers['raw_logits'] = raw_logits

    # Output shape: [n_steps, batch_size, n_hidden_6]
    return raw_logits, layers
|
||||||
|
|
||||||
|
|
||||||
|
# Accuracy and Loss
|
||||||
|
# =================
|
||||||
|
|
||||||
|
# In accord with 'Deep Speech: Scaling up end-to-end speech recognition'
|
||||||
|
# (http://arxiv.org/abs/1412.5567),
|
||||||
|
# the loss function used by our network should be the CTC loss function
|
||||||
|
# (http://www.cs.toronto.edu/~graves/preprint.pdf).
|
||||||
|
# Conveniently, this loss function is implemented in TensorFlow.
|
||||||
|
# Thus, we can simply make use of this implementation to define our loss.
|
||||||
|
|
||||||
|
def calculate_mean_edit_distance_and_loss(iterator, dropout, reuse):
    r'''
    Fetch the next mini-batch from `iterator`, run it through the model and
    compute its CTC loss.

    The RNN implementation is `rnn_impl_cudnn_rnn` when `FLAGS.train_cudnn` is
    set and `rnn_impl_lstmblockfusedcell` otherwise.

    Returns a pair: the loss averaged over the batch, and the entries of the
    batch's filenames whose individual loss is not finite.
    '''
    # Obtain the next batch of data
    batch_filenames, (batch_x, batch_seq_len), batch_y = iterator.get_next()

    rnn_impl = rnn_impl_cudnn_rnn if FLAGS.train_cudnn else rnn_impl_lstmblockfusedcell

    # Calculate the logits of the batch
    logits, _ = create_model(batch_x, batch_seq_len, dropout, reuse=reuse, rnn_impl=rnn_impl)

    # One CTC loss value per sample, computed by TensorFlow's `ctc_loss`
    per_sample_loss = tfv1.nn.ctc_loss(labels=batch_y, inputs=logits, sequence_length=batch_seq_len)

    # Pick out the files whose loss is infinite or NaN
    bad_sample_indices = tfv1.where(~tf.math.is_finite(per_sample_loss))
    non_finite_files = tf.gather(batch_filenames, bad_sample_indices)

    # Average loss across the batch
    avg_loss = tf.reduce_mean(input_tensor=per_sample_loss)

    return avg_loss, non_finite_files
|
||||||
|
|
||||||
|
|
||||||
|
# Adam Optimization
|
||||||
|
# =================
|
||||||
|
|
||||||
|
# In contrast to 'Deep Speech: Scaling up end-to-end speech recognition'
|
||||||
|
# (http://arxiv.org/abs/1412.5567),
|
||||||
|
# in which 'Nesterov's Accelerated Gradient Descent'
|
||||||
|
# (www.cs.toronto.edu/~fritz/absps/momentum.pdf) was used,
|
||||||
|
# we will use the Adam method for optimization (http://arxiv.org/abs/1412.6980),
|
||||||
|
# because, generally, it requires less fine-tuning.
|
||||||
|
def create_optimizer(learning_rate_var):
    r'''
    Build the Adam optimizer used for training.

    The learning rate comes from `learning_rate_var`; beta1, beta2 and epsilon
    are read from the corresponding command line flags.
    '''
    adam_settings = {
        'learning_rate': learning_rate_var,
        'beta1': FLAGS.beta1,
        'beta2': FLAGS.beta2,
        'epsilon': FLAGS.epsilon,
    }
    return tfv1.train.AdamOptimizer(**adam_settings)
|
||||||
|
|
||||||
|
|
||||||
|
# Towers
|
||||||
|
# ======
|
||||||
|
|
||||||
|
# In order to properly make use of multiple GPUs, one must introduce new abstractions,
|
||||||
|
# not present when using a single GPU, that facilitate the multi-GPU use case.
|
||||||
|
# In particular, one must introduce a means to isolate the inference and gradient
|
||||||
|
# calculations on the various GPUs.
|
||||||
|
# The abstraction we introduce for this purpose is called a 'tower'.
|
||||||
|
# A tower is specified by two properties:
|
||||||
|
# * **Scope** - A scope, as provided by `tf.name_scope()`,
|
||||||
|
# is a means to isolate the operations within a tower.
|
||||||
|
# For example, all operations within 'tower 0' could have their name prefixed with `tower_0/`.
|
||||||
|
# * **Device** - A hardware device, as provided by `tf.device()`,
|
||||||
|
# on which all operations within the tower execute.
|
||||||
|
# For example, all operations of 'tower 0' could execute on the first GPU `tf.device('/gpu:0')`.
|
||||||
|
|
||||||
|
def get_tower_results(iterator, optimizer, dropout_rates):
    r'''
    Build one tower per entry of `Config.available_devices`. Each tower pulls
    its own batch from `iterator`, computes the average loss for it and the
    gradients of that loss as given by `optimizer.compute_gradients`.

    Returns a triple: the list of per-tower gradients, the loss averaged over
    all towers (also published as the `step_loss` scalar in the
    `step_summaries` collection) and the concatenation of every tower's files
    with a non-finite loss.
    '''
    per_tower_losses = []
    per_tower_gradients = []
    per_tower_bad_files = []

    with tfv1.variable_scope(tfv1.get_variable_scope()):
        for tower_index, device in enumerate(Config.available_devices):
            # Run the operations of this tower on its own device, grouped
            # under a name scope of their own
            with tf.device(device), tf.name_scope('tower_%d' % tower_index):
                # Only the first tower creates variables, later ones reuse them
                tower_loss, tower_bad_files = calculate_mean_edit_distance_and_loss(
                    iterator, dropout_rates, reuse=tower_index > 0)

                # Allow for variables to be re-used by the next tower
                tfv1.get_variable_scope().reuse_variables()

                per_tower_losses.append(tower_loss)

                # Gradients for the model parameters from this tower's mini-batch
                per_tower_gradients.append(optimizer.compute_gradients(tower_loss))

                per_tower_bad_files.append(tower_bad_files)

    avg_loss_across_towers = tf.reduce_mean(input_tensor=per_tower_losses, axis=0)
    tfv1.summary.scalar(name='step_loss', tensor=avg_loss_across_towers, collections=['step_summaries'])

    all_non_finite_files = tf.concat(per_tower_bad_files, axis=0)

    return per_tower_gradients, avg_loss_across_towers, all_non_finite_files
|
||||||
|
|
||||||
|
|
||||||
|
def average_gradients(tower_gradients):
    r'''
    Average each variable's gradient over all towers.

    `tower_gradients` holds one list of (gradient, variable) pairs per tower,
    with the variables in the same order in every list. The result is a single
    list of (averaged gradient, variable) pairs, where the variable is taken
    from the first tower. The averaging ops are placed on `Config.cpu_device`.
    '''
    averaged = []

    # Run this on cpu_device to conserve GPU memory
    with tf.device(Config.cpu_device):
        # Each iteration sees the pairs belonging to one variable, one per tower
        for pairs_for_variable in zip(*tower_gradients):
            # Give every gradient a leading 'tower' axis
            with_tower_axis = [tf.expand_dims(gradient, 0) for gradient, _ in pairs_for_variable]

            # Join along the 'tower' axis and average it away
            stacked = tf.concat(with_tower_axis, 0)
            mean_gradient = tf.reduce_mean(input_tensor=stacked, axis=0)

            shared_variable = pairs_for_variable[0][1]
            averaged.append((mean_gradient, shared_variable))

    return averaged
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# Logging
|
||||||
|
# =======
|
||||||
|
|
||||||
|
def log_variable(variable, gradient=None):
    r'''
    Register summaries describing the current state of `variable`: scalars for
    its mean, standard deviation, maximum and minimum, plus a histogram of its
    values. When a `gradient` is given, a histogram of the gradient values is
    registered as well (for `tf.IndexedSlices` its `.values` are used).
    '''
    name = variable.name.replace(':', '_')

    mean = tf.reduce_mean(input_tensor=variable)
    tfv1.summary.scalar(name='%s/mean' % name, tensor=mean)

    squared_deviation = tf.square(variable - mean)
    std_dev = tf.sqrt(tf.reduce_mean(input_tensor=squared_deviation))
    tfv1.summary.scalar(name='%s/sttdev' % name, tensor=std_dev)

    tfv1.summary.scalar(name='%s/max' % name, tensor=tf.reduce_max(input_tensor=variable))
    tfv1.summary.scalar(name='%s/min' % name, tensor=tf.reduce_min(input_tensor=variable))
    tfv1.summary.histogram(name=name, values=variable)

    if gradient is None:
        return

    grad_values = gradient.values if isinstance(gradient, tf.IndexedSlices) else gradient
    if grad_values is not None:
        tfv1.summary.histogram(name='%s/gradients' % name, values=grad_values)
|
||||||
|
|
||||||
|
|
||||||
|
def log_grads_and_vars(grads_and_vars):
    r'''
    Call `log_variable` for every (gradient, variable) pair in `grads_and_vars`.
    '''
    for pair in grads_and_vars:
        grad, var = pair
        log_variable(var, gradient=grad)
|
||||||
|
|
||||||
|
|
||||||
|
def train():
    r'''
    Build the training graph and run the training loop.

    Sets up the training dataset (and one validation dataset per entry of
    `FLAGS.dev_files`, if given), the per-device towers, the optimizer, the
    summaries and the checkpoint savers. It then trains for up to
    `FLAGS.epochs` epochs. After every epoch a checkpoint is saved and, when
    validation sets are configured, the validation loss drives best-model
    saving, early stopping and learning rate reduction on a plateau.
    A KeyboardInterrupt ends the epoch loop without raising.
    '''
    do_cache_dataset = True

    # Feature caching is switched off as soon as any augmentation is enabled
    # pylint: disable=too-many-boolean-expressions
    if (FLAGS.data_aug_features_multiplicative > 0 or
            FLAGS.data_aug_features_additive > 0 or
            FLAGS.augmentation_spec_dropout_keeprate < 1 or
            FLAGS.augmentation_freq_and_time_masking or
            FLAGS.augmentation_pitch_and_tempo_scaling or
            FLAGS.augmentation_speed_up_std > 0 or
            FLAGS.augmentation_sparse_warp):
        do_cache_dataset = False

    # Shared with the dataset pipeline; polled in the batch loop via raise_if_set()
    exception_box = ExceptionBox()

    # Create training and validation datasets
    train_set = create_dataset(FLAGS.train_files.split(','),
                               batch_size=FLAGS.train_batch_size,
                               enable_cache=FLAGS.feature_cache and do_cache_dataset,
                               cache_path=FLAGS.feature_cache,
                               train_phase=True,
                               exception_box=exception_box,
                               process_ahead=len(Config.available_devices) * FLAGS.train_batch_size * 2,
                               buffering=FLAGS.read_buffer)

    # One re-initializable iterator serves the training set and every dev set
    iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
                                                 tfv1.data.get_output_shapes(train_set),
                                                 output_classes=tfv1.data.get_output_classes(train_set))

    # Make initialization ops for switching between the two sets
    train_init_op = iterator.make_initializer(train_set)

    if FLAGS.dev_files:
        dev_sources = FLAGS.dev_files.split(',')
        dev_sets = [create_dataset([source],
                                   batch_size=FLAGS.dev_batch_size,
                                   train_phase=False,
                                   exception_box=exception_box,
                                   process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2,
                                   buffering=FLAGS.read_buffer) for source in dev_sources]
        dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]

    # Dropout
    # Six placeholders, fed with the configured rates while training and with
    # zeros while validating
    dropout_rates = [tfv1.placeholder(tf.float32, name='dropout_{}'.format(i)) for i in range(6)]
    dropout_feed_dict = {
        dropout_rates[0]: FLAGS.dropout_rate,
        dropout_rates[1]: FLAGS.dropout_rate2,
        dropout_rates[2]: FLAGS.dropout_rate3,
        dropout_rates[3]: FLAGS.dropout_rate4,
        dropout_rates[4]: FLAGS.dropout_rate5,
        dropout_rates[5]: FLAGS.dropout_rate6,
    }
    no_dropout_feed_dict = {
        rate: 0. for rate in dropout_rates
    }

    # Building the graph
    learning_rate_var = tfv1.get_variable('learning_rate', initializer=FLAGS.learning_rate, trainable=False)
    # Multiplies the stored learning rate by FLAGS.plateau_reduction when run
    reduce_learning_rate_op = learning_rate_var.assign(tf.multiply(learning_rate_var, FLAGS.plateau_reduction))
    optimizer = create_optimizer(learning_rate_var)

    # Enable mixed precision training
    if FLAGS.automatic_mixed_precision:
        log_info('Enabling automatic mixed precision training.')
        optimizer = tfv1.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)

    gradients, loss, non_finite_files = get_tower_results(iterator, optimizer, dropout_rates)

    # Average tower gradients across GPUs
    avg_tower_gradients = average_gradients(gradients)
    log_grads_and_vars(avg_tower_gradients)

    # global_step is automagically incremented by the optimizer
    global_step = tfv1.train.get_or_create_global_step()
    apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients, global_step=global_step)

    # Summaries
    step_summaries_op = tfv1.summary.merge_all('step_summaries')
    # Keyed by the set name that run_set() receives ('train' or 'dev')
    step_summary_writers = {
        'train': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'train'), max_queue=120),
        'dev': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'dev'), max_queue=120)
    }

    # Checkpointing
    checkpoint_saver = tfv1.train.Saver(max_to_keep=FLAGS.max_to_keep)
    checkpoint_path = os.path.join(FLAGS.save_checkpoint_dir, 'train')

    # Separate saver that keeps only the single best validation checkpoint
    best_dev_saver = tfv1.train.Saver(max_to_keep=1)
    best_dev_path = os.path.join(FLAGS.save_checkpoint_dir, 'best_dev')

    # Save flags next to checkpoints
    os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True)
    flags_file = os.path.join(FLAGS.save_checkpoint_dir, 'flags.txt')
    with open(flags_file, 'w') as fout:
        fout.write(FLAGS.flags_into_string())

    with tfv1.Session(config=Config.session_config) as session:
        log_debug('Session opened.')

        # Prevent further graph changes
        tfv1.get_default_graph().finalize()

        # Load checkpoint or initialize variables
        if FLAGS.load == 'auto':
            method_order = ['best', 'last', 'init']
        else:
            method_order = [FLAGS.load]
        load_or_init_graph(session, method_order)

        def run_set(set_name, epoch, init_op, dataset=None):
            # Runs one full pass over the dataset selected by `init_op` and
            # returns (mean batch loss, number of steps). Gradients are applied
            # only when `set_name` is 'train'.
            is_train = set_name == 'train'
            # An empty fetch list makes the same session.run() call a no-op
            # for the training step during validation
            train_op = apply_gradient_op if is_train else []
            feed_dict = dropout_feed_dict if is_train else no_dropout_feed_dict

            total_loss = 0.0
            step_count = 0

            step_summary_writer = step_summary_writers.get(set_name)
            checkpoint_time = time.time()

            # Setup progress bar
            class LossWidget(progressbar.widgets.FormatLabel):
                # Progress bar label showing the running mean loss; it reads
                # total_loss and step_count from the enclosing run_set() scope.
                def __init__(self):
                    progressbar.widgets.FormatLabel.__init__(self, format='Loss: %(mean_loss)f')

                def __call__(self, progress, data, **kwargs):
                    data['mean_loss'] = total_loss / step_count if step_count else 0.0
                    return progressbar.widgets.FormatLabel.__call__(self, progress, data, **kwargs)

            prefix = 'Epoch {} | {:>10}'.format(epoch, 'Training' if is_train else 'Validation')
            widgets = [' | ', progressbar.widgets.Timer(),
                       ' | Steps: ', progressbar.widgets.Counter(),
                       ' | ', LossWidget()]
            suffix = ' | Dataset: {}'.format(dataset) if dataset else None
            pbar = create_progressbar(prefix=prefix, widgets=widgets, suffix=suffix).start()

            # Initialize iterator to the appropriate dataset
            session.run(init_op)

            # Batch loop
            while True:
                try:
                    _, current_step, batch_loss, problem_files, step_summary = \
                        session.run([train_op, global_step, loss, non_finite_files, step_summaries_op],
                                    feed_dict=feed_dict)
                    exception_box.raise_if_set()
                except tf.errors.InvalidArgumentError as err:
                    # Only tolerated while sparse warp augmentation is on; the
                    # failing step is skipped without being counted
                    if FLAGS.augmentation_sparse_warp:
                        log_info("Ignoring sparse warp error: {}".format(err))
                        continue
                    raise
                except tf.errors.OutOfRangeError:
                    # The iterator is exhausted: this pass over the set is done
                    exception_box.raise_if_set()
                    break

                if problem_files.size > 0:
                    problem_files = [f.decode('utf8') for f in problem_files[..., 0]]
                    log_error('The following files caused an infinite (or NaN) '
                              'loss: {}'.format(','.join(problem_files)))

                total_loss += batch_loss
                step_count += 1

                pbar.update(step_count)

                step_summary_writer.add_summary(step_summary, current_step)

                # Time-based intermediate checkpoint, training passes only
                if is_train and FLAGS.checkpoint_secs > 0 and time.time() - checkpoint_time > FLAGS.checkpoint_secs:
                    checkpoint_saver.save(session, checkpoint_path, global_step=current_step)
                    checkpoint_time = time.time()

            pbar.finish()
            mean_loss = total_loss / step_count if step_count > 0 else 0.0
            return mean_loss, step_count

        log_info('STARTING Optimization')
        train_start_time = datetime.utcnow()
        best_dev_loss = float('inf')
        dev_losses = []
        epochs_without_improvement = 0
        try:
            for epoch in range(FLAGS.epochs):
                # Training
                log_progress('Training epoch %d...' % epoch)
                train_loss, _ = run_set('train', epoch, train_init_op)
                log_progress('Finished training epoch %d - loss: %f' % (epoch, train_loss))
                checkpoint_saver.save(session, checkpoint_path, global_step=global_step)

                if FLAGS.dev_files:
                    # Validation
                    dev_loss = 0.0
                    total_steps = 0
                    for source, init_op in zip(dev_sources, dev_init_ops):
                        log_progress('Validating epoch %d on %s...' % (epoch, source))
                        set_loss, steps = run_set('dev', epoch, init_op, dataset=source)
                        # Weight each set's mean loss by its number of steps
                        dev_loss += set_loss * steps
                        total_steps += steps
                        log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, source, set_loss))

                    # NOTE(review): divides by zero if no dev set produced a
                    # single step - confirm whether that can happen in practice
                    dev_loss = dev_loss / total_steps
                    dev_losses.append(dev_loss)

                    # Count epochs without an improvement for early stopping and reduction of learning rate on a plateau
                    # the improvement has to be greater than FLAGS.es_min_delta
                    if dev_loss > best_dev_loss - FLAGS.es_min_delta:
                        epochs_without_improvement += 1
                    else:
                        epochs_without_improvement = 0

                    # Save new best model
                    if dev_loss < best_dev_loss:
                        best_dev_loss = dev_loss
                        save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename='best_dev_checkpoint')
                        log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path))

                    # Early stopping
                    if FLAGS.early_stop and epochs_without_improvement == FLAGS.es_epochs:
                        log_info('Early stop triggered as the loss did not improve the last {} epochs'.format(
                            epochs_without_improvement))
                        break

                    # Reduce learning rate on plateau
                    if (FLAGS.reduce_lr_on_plateau and
                            epochs_without_improvement % FLAGS.plateau_epochs == 0 and epochs_without_improvement > 0):
                        # If the learning rate was reduced and there is still no improvement
                        # wait FLAGS.plateau_epochs before the learning rate is reduced again
                        session.run(reduce_learning_rate_op)
                        current_learning_rate = learning_rate_var.eval()
                        log_info('Encountered a plateau, reducing learning rate to {}'.format(
                            current_learning_rate))

        except KeyboardInterrupt:
            # Ctrl-C ends training early; the finishing log lines below still run
            pass
        log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time))
    log_debug('Session closed.')
|
||||||
|
|
||||||
|
|
||||||
|
def test():
    r'''
    Evaluate the model on the files listed in `FLAGS.test_files` (comma
    separated) and, when `FLAGS.test_output_file` is set, write the resulting
    samples to that file as JSON.
    '''
    samples = evaluate(FLAGS.test_files.split(','), create_model)
    if FLAGS.test_output_file:
        # Save decoded tuples as JSON, converting NumPy floats to Python floats.
        # The context manager makes sure the file is flushed and closed even if
        # serialization fails half way.
        with open(FLAGS.test_output_file, 'w') as fout:
            json.dump(samples, fout, default=float)
|
||||||
|
|
||||||
|
|
||||||
|
def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
    r'''
    Build the graph used for inference and export.

    A `batch_size` that is not positive selects a dynamic batch dimension. In
    that mode no LSTM state placeholders are created, `tflite` must be false
    and `n_steps` must not be positive either. A non-positive `n_steps` makes
    the time dimension of the input placeholder dynamic.

    Returns `(inputs, outputs, layers)`: two dicts mapping names to the input
    and output tensors, and the layer dict produced by `create_model`.
    With a dynamic batch size only `input`/`input_lengths` and `outputs` are
    present in the two dicts.

    Raises NotImplementedError for a dynamic batch size combined with `tflite`
    or with a fixed `n_steps`.
    '''
    batch_size = batch_size if batch_size > 0 else None
    # From here on batch_size is either a positive int or None, so the dynamic
    # case must be detected with an identity test: ordering comparisons against
    # None raise TypeError on Python 3.
    dynamic_batch = batch_size is None

    # Create feature computation graph
    input_samples = tfv1.placeholder(tf.float32, [Config.audio_window_samples], 'input_samples')
    samples = tf.expand_dims(input_samples, -1)
    mfccs, _ = samples_to_mfccs(samples, FLAGS.audio_sample_rate)
    mfccs = tf.identity(mfccs, name='mfccs')

    # Input tensor will be of shape [batch_size, n_steps, 2*n_context+1, n_input]
    # This shape is read by the native_client in DS_CreateModel to know the
    # value of n_steps, n_context and n_input. Make sure you update the code
    # there if this shape is changed.
    input_tensor = tfv1.placeholder(tf.float32, [batch_size, n_steps if n_steps > 0 else None, 2 * Config.n_context + 1, Config.n_input], name='input_node')
    seq_length = tfv1.placeholder(tf.int32, [batch_size], name='input_lengths')

    if dynamic_batch:
        # no state management since n_step is expected to be dynamic too (see below)
        previous_state = None
    else:
        previous_state_c = tfv1.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_c')
        previous_state_h = tfv1.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_h')

        previous_state = tf.nn.rnn_cell.LSTMStateTuple(previous_state_c, previous_state_h)

    # One rate per layer
    no_dropout = [None] * 6

    if tflite:
        rnn_impl = rnn_impl_static_rnn
    else:
        rnn_impl = rnn_impl_lstmblockfusedcell

    # NOTE(review): the sequence length is dropped based on FLAGS.export_tflite
    # rather than on the `tflite` argument - confirm this is intended for
    # callers that pass a `tflite` value different from the flag.
    logits, layers = create_model(batch_x=input_tensor,
                                  batch_size=batch_size,
                                  seq_length=seq_length if not FLAGS.export_tflite else None,
                                  dropout=no_dropout,
                                  previous_state=previous_state,
                                  overlap=False,
                                  rnn_impl=rnn_impl)

    # TF Lite runtime will check that input dimensions are 1, 2 or 4
    # by default we get 3, the middle one being batch_size which is forced to
    # one on inference graph, so remove that dimension
    if tflite:
        logits = tf.squeeze(logits, [1])

    # Apply softmax for CTC decoder
    logits = tf.nn.softmax(logits, name='logits')

    if dynamic_batch:
        if tflite:
            raise NotImplementedError('dynamic batch_size does not support tflite nor streaming')
        if n_steps > 0:
            raise NotImplementedError('dynamic batch_size expect n_steps to be dynamic too')
        return (
            {
                'input': input_tensor,
                'input_lengths': seq_length,
            },
            {
                'outputs': logits,
            },
            layers
        )

    # Give the final LSTM state stable names so it can be fetched and fed back
    # in through the previous_state_* placeholders
    new_state_c, new_state_h = layers['rnn_output_state']
    new_state_c = tf.identity(new_state_c, name='new_state_c')
    new_state_h = tf.identity(new_state_h, name='new_state_h')

    inputs = {
        'input': input_tensor,
        'previous_state_c': previous_state_c,
        'previous_state_h': previous_state_h,
        'input_samples': input_samples,
    }

    if not FLAGS.export_tflite:
        inputs['input_lengths'] = seq_length

    outputs = {
        'outputs': logits,
        'new_state_c': new_state_c,
        'new_state_h': new_state_h,
        'mfccs': mfccs,
    }

    return inputs, outputs, layers
|
||||||
|
|
||||||
|
|
||||||
|
def file_relative_read(fname):
    r'''
    Return the text content of `fname`, resolved relative to the directory
    that contains this source file.
    '''
    path = os.path.join(os.path.dirname(__file__), fname)
    # Close the handle deterministically instead of leaving it to the garbage collector
    with open(path) as fin:
        return fin.read()
|
||||||
|
|
||||||
|
|
||||||
|
def export():
    r'''
    Restores the trained variables into a simpler graph that will be exported for serving.

    The inference graph is extended with `metadata_*` constants, the variables
    are loaded from a checkpoint and frozen, and the result is written to
    `FLAGS.export_dir` either as a `.pb` protobuf or, with
    `FLAGS.export_tflite`, as a `.tflite` model. A Markdown metadata file
    describing the model is written next to it.
    '''
    log_info('Exporting the model...')

    inputs, outputs, _ = create_inference_graph(batch_size=FLAGS.export_batch_size, n_steps=FLAGS.n_steps, tflite=FLAGS.export_tflite)

    graph_version = int(file_relative_read('GRAPH_VERSION').strip())
    assert graph_version > 0

    # Every metadata entry becomes a one-element constant named after its key
    metadata_entries = [
        ('metadata_version', graph_version),
        ('metadata_sample_rate', FLAGS.audio_sample_rate),
        ('metadata_feature_win_len', FLAGS.feature_win_len),
        ('metadata_feature_win_step', FLAGS.feature_win_step),
        ('metadata_beam_width', FLAGS.export_beam_width),
        ('metadata_alphabet', Config.alphabet.serialize()),
    ]
    if FLAGS.export_language:
        metadata_entries.append(('metadata_language', FLAGS.export_language.encode('utf-8')))

    for entry_name, entry_value in metadata_entries:
        outputs[entry_name] = tf.constant([entry_value], name=entry_name)

    # Prevent further graph changes
    tfv1.get_default_graph().finalize()

    # Node names of all outputs: tensors first (by their producing op), then plain ops
    output_nodes = list(outputs.values())
    output_names = [node.op.name for node in output_nodes if isinstance(node, tf.Tensor)]
    output_names += [node.name for node in output_nodes if isinstance(node, tf.Operation)]

    with tf.Session() as session:
        # Restore variables from checkpoint
        method_order = ['best', 'last'] if FLAGS.load == 'auto' else [FLAGS.load]
        load_or_init_graph(session, method_order)

        output_filename = FLAGS.export_file_name + '.pb'
        if FLAGS.remove_export and os.path.isdir(FLAGS.export_dir):
            log_info('Removing old export')
            shutil.rmtree(FLAGS.export_dir)

        output_graph_path = os.path.join(FLAGS.export_dir, output_filename)

        if not os.path.isdir(FLAGS.export_dir):
            os.makedirs(FLAGS.export_dir)

        # Turn variables into constants, then drop everything the outputs don't need
        frozen_graph = tfv1.graph_util.convert_variables_to_constants(
            sess=session,
            input_graph_def=tfv1.get_default_graph().as_graph_def(),
            output_node_names=output_names)

        frozen_graph = tfv1.graph_util.extract_sub_graph(
            graph_def=frozen_graph,
            dest_nodes=output_names)

        if FLAGS.export_tflite:
            output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))

            converter = tf.lite.TFLiteConverter(frozen_graph, input_tensors=inputs.values(), output_tensors=outputs.values())
            converter.optimizations = [tf.lite.Optimize.DEFAULT]
            # AudioSpectrogram and Mfcc ops are custom but have built-in kernels in TFLite
            converter.allow_custom_ops = True
            tflite_model = converter.convert()

            with open(output_tflite_path, 'wb') as fout:
                fout.write(tflite_model)
        else:
            with open(output_graph_path, 'wb') as fout:
                fout.write(frozen_graph.SerializeToString())

        log_info('Models exported at %s' % (FLAGS.export_dir))

    metadata_basename = '{}_{}_{}.md'.format(
        FLAGS.export_author_id,
        FLAGS.export_model_name,
        FLAGS.export_model_version)
    metadata_fname = os.path.join(FLAGS.export_dir, metadata_basename)

    model_runtime = 'tflite' if FLAGS.export_tflite else 'tensorflow'

    # Front matter block followed by the free-form description
    metadata_lines = [
        '---\n',
        'author: {}\n'.format(FLAGS.export_author_id),
        'model_name: {}\n'.format(FLAGS.export_model_name),
        'model_version: {}\n'.format(FLAGS.export_model_version),
        'contact_info: {}\n'.format(FLAGS.export_contact_info),
        'license: {}\n'.format(FLAGS.export_license),
        'language: {}\n'.format(FLAGS.export_language),
        'runtime: {}\n'.format(model_runtime),
        'min_ds_version: {}\n'.format(FLAGS.export_min_ds_version),
        'max_ds_version: {}\n'.format(FLAGS.export_max_ds_version),
        'acoustic_model_url: <replace this with a publicly available URL of the acoustic model>\n',
        'scorer_url: <replace this with a publicly available URL of the scorer, if present>\n',
        '---\n',
        '{}\n'.format(FLAGS.export_description),
    ]
    with open(metadata_fname, 'w') as f:
        f.writelines(metadata_lines)

    log_info('Model metadata file saved to {}. Before submitting the exported model for publishing make sure all information in the metadata file is correct, and complete the URL fields.'.format(metadata_fname))
|
||||||
|
|
||||||
|
|
||||||
|
def package_zip():
    """Bundle the export directory, plus the scorer file, into a zip archive.

    The archive is written next to the export directory and named after it:
    ``--export_dir path/to/export/LANG_CODE/`` => ``path/to/export/LANG_CODE.zip``.
    The file at ``FLAGS.scorer_path`` is copied into the export directory
    first so that it ends up inside the archive.
    """
    # Joining with '' forces a trailing separator, so dirname() below yields
    # the export directory itself (without the separator) as the archive base.
    bundle_dir = os.path.join(os.path.abspath(FLAGS.export_dir), '')
    archive_base = os.path.dirname(bundle_dir)

    shutil.copy(FLAGS.scorer_path, bundle_dir)

    # make_archive appends the '.zip' extension to the base name.
    packaged = shutil.make_archive(archive_base, 'zip', bundle_dir)
    log_info('Exported packaged model {}'.format(packaged))
|
||||||
|
|
||||||
|
|
||||||
|
def do_single_file_inference(input_file_path):
    """Transcribe one audio file and print the best decoding to stdout.

    Builds the inference graph for a batch of one, restores the weights
    according to ``FLAGS.load``, runs the acoustic model on the features of
    ``input_file_path`` and decodes the result with the CTC beam search
    decoder (using an external scorer when ``FLAGS.scorer_path`` is set).
    """
    with tfv1.Session(config=Config.session_config) as session:
        inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=-1)

        # Restore variables from training checkpoint
        load_order = ['best', 'last'] if FLAGS.load == 'auto' else [FLAGS.load]
        load_or_init_graph(session, load_order)

        feature_op, length_op = audiofile_to_features(input_file_path)

        # Add batch dimension
        batched_features = tf.expand_dims(feature_op, 0)
        batched_lengths = tf.expand_dims(length_op, 0)

        # Evaluate the feature pipeline into concrete arrays
        windowed = create_overlapping_windows(batched_features).eval(session=session)
        lengths = batched_lengths.eval(session=session)

        # The recurrent state starts out zeroed for a fresh utterance.
        state_shape = [1, Config.n_cell_dim]
        feed = {
            inputs['input']: windowed,
            inputs['input_lengths']: lengths,
            inputs['previous_state_c']: np.zeros(state_shape),
            inputs['previous_state_h']: np.zeros(state_shape),
        }
        logits = np.squeeze(outputs['outputs'].eval(feed_dict=feed, session=session))

        scorer = (Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
                         FLAGS.scorer_path, Config.alphabet)
                  if FLAGS.scorer_path else None)
        decoded = ctc_beam_search_decoder(logits, Config.alphabet, FLAGS.beam_width,
                                          scorer=scorer, cutoff_prob=FLAGS.cutoff_prob,
                                          cutoff_top_n=FLAGS.cutoff_top_n)
        # Print highest probability result
        print(decoded[0][1])
|
||||||
|
|
||||||
|
|
||||||
|
def main(_):
    """Run the stages selected by the command-line flags, in a fixed order.

    Stages, each executed on a freshly reset default graph:
      1. training, when ``FLAGS.train_files`` is set;
      2. testing, when ``FLAGS.test_files`` is set;
      3. export to ``FLAGS.export_dir`` (plain export, or TFLite export
         followed by zip packaging when ``FLAGS.export_zip`` is set);
      4. one-shot inference on the file named by ``FLAGS.one_shot_infer``.

    The single positional argument is ignored.
    """
    initialize_globals()

    if FLAGS.train_files:
        tfv1.reset_default_graph()
        tfv1.set_random_seed(FLAGS.random_seed)
        train()

    if FLAGS.test_files:
        tfv1.reset_default_graph()
        test()

    # Plain export; the zip case is handled by the next branch instead.
    if FLAGS.export_dir and not FLAGS.export_zip:
        tfv1.reset_default_graph()
        export()

    if FLAGS.export_zip:
        tfv1.reset_default_graph()
        # Zip packages always carry a TFLite model.
        FLAGS.export_tflite = True

        # Refuse to package a directory that already has content, since
        # everything in it would end up in the archive.
        if os.listdir(FLAGS.export_dir):
            log_error('Directory {} is not empty, please fix this.'.format(FLAGS.export_dir))
            sys.exit(1)

        export()
        package_zip()

    if FLAGS.one_shot_infer:
        tfv1.reset_default_graph()
        do_single_file_inference(FLAGS.one_shot_infer)
|
||||||
|
|
||||||
|
|
||||||
|
def run_script():
    """Script entry point: call ``create_flags()`` and hand ``main`` to absl.

    ``create_flags`` presumably defines the command-line flags read through
    ``FLAGS`` -- confirm in the flags module.
    """
    create_flags()
    absl.app.run(main)
|
||||||
|
|
||||||
|
# Start the command-line entry point only when executed as a script, not on import.
if __name__ == '__main__':
    run_script()
|
|
@ -5,7 +5,7 @@ import tempfile
|
||||||
import collections
|
import collections
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from util.helpers import LimitingPool
|
from .helpers import LimitingPool
|
||||||
|
|
||||||
DEFAULT_RATE = 16000
|
DEFAULT_RATE = 16000
|
||||||
DEFAULT_CHANNELS = 1
|
DEFAULT_CHANNELS = 1
|
|
@ -2,8 +2,8 @@ import sys
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
import tensorflow.compat.v1 as tfv1
|
import tensorflow.compat.v1 as tfv1
|
||||||
|
|
||||||
from util.flags import FLAGS
|
from .flags import FLAGS
|
||||||
from util.logging import log_info, log_error, log_warn
|
from .logging import log_info, log_error, log_warn
|
||||||
|
|
||||||
|
|
||||||
def _load_checkpoint(session, checkpoint_path):
|
def _load_checkpoint(session, checkpoint_path):
|
|
@ -8,11 +8,11 @@ import tensorflow.compat.v1 as tfv1
|
||||||
from attrdict import AttrDict
|
from attrdict import AttrDict
|
||||||
from xdg import BaseDirectory as xdg
|
from xdg import BaseDirectory as xdg
|
||||||
|
|
||||||
from util.flags import FLAGS
|
from .flags import FLAGS
|
||||||
from util.gpu import get_available_gpus
|
from .gpu import get_available_gpus
|
||||||
from util.logging import log_error
|
from .logging import log_error
|
||||||
from util.text import Alphabet, UTF8Alphabet
|
from .text import Alphabet, UTF8Alphabet
|
||||||
from util.helpers import parse_file_size
|
from .helpers import parse_file_size
|
||||||
|
|
||||||
class ConfigSingleton:
|
class ConfigSingleton:
|
||||||
_config = None
|
_config = None
|
|
@ -7,8 +7,8 @@ import numpy as np
|
||||||
|
|
||||||
from attrdict import AttrDict
|
from attrdict import AttrDict
|
||||||
|
|
||||||
from util.flags import FLAGS
|
from .flags import FLAGS
|
||||||
from util.text import levenshtein
|
from .text import levenshtein
|
||||||
|
|
||||||
|
|
||||||
def pmap(fun, iterable):
|
def pmap(fun, iterable):
|
|
@ -8,13 +8,13 @@ import tensorflow as tf
|
||||||
|
|
||||||
from tensorflow.python.ops import gen_audio_ops as contrib_audio
|
from tensorflow.python.ops import gen_audio_ops as contrib_audio
|
||||||
|
|
||||||
from util.config import Config
|
from .config import Config
|
||||||
from util.text import text_to_char_array
|
from .text import text_to_char_array
|
||||||
from util.flags import FLAGS
|
from .flags import FLAGS
|
||||||
from util.spectrogram_augmentations import augment_freq_time_mask, augment_dropout, augment_pitch_and_tempo, augment_speed_up, augment_sparse_warp
|
from .spectrogram_augmentations import augment_freq_time_mask, augment_dropout, augment_pitch_and_tempo, augment_speed_up, augment_sparse_warp
|
||||||
from util.audio import change_audio_types, read_frames_from_file, vad_split, pcm_to_np, DEFAULT_FORMAT, AUDIO_TYPE_NP
|
from .audio import change_audio_types, read_frames_from_file, vad_split, pcm_to_np, DEFAULT_FORMAT, AUDIO_TYPE_NP
|
||||||
from util.sample_collections import samples_from_files
|
from .sample_collections import samples_from_files
|
||||||
from util.helpers import remember_exception, MEGABYTE
|
from .helpers import remember_exception, MEGABYTE
|
||||||
|
|
||||||
|
|
||||||
def samples_to_mfccs(samples, sample_rate, train_phase=False, sample_id=None):
|
def samples_to_mfccs(samples, sample_rate, train_phase=False, sample_id=None):
|
|
@ -4,7 +4,7 @@ import os
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from util.helpers import secs_to_hours
|
from .helpers import secs_to_hours
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
|
|
||||||
def get_counter():
|
def get_counter():
|
|
@ -3,7 +3,7 @@ from __future__ import print_function
|
||||||
import progressbar
|
import progressbar
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from util.flags import FLAGS
|
from .flags import FLAGS
|
||||||
|
|
||||||
|
|
||||||
# Logging functions
|
# Logging functions
|
|
@ -5,8 +5,8 @@ import json
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from util.helpers import MEGABYTE, GIGABYTE, Interleaved
|
from .helpers import MEGABYTE, GIGABYTE, Interleaved
|
||||||
from util.audio import Sample, DEFAULT_FORMAT, AUDIO_TYPE_WAV, AUDIO_TYPE_OPUS, SERIALIZABLE_AUDIO_TYPES
|
from .audio import Sample, DEFAULT_FORMAT, AUDIO_TYPE_WAV, AUDIO_TYPE_OPUS, SERIALIZABLE_AUDIO_TYPES
|
||||||
|
|
||||||
BIG_ENDIAN = 'big'
|
BIG_ENDIAN = 'big'
|
||||||
INT_SIZE = 4
|
INT_SIZE = 4
|
|
@ -1,6 +1,7 @@
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
import tensorflow.compat.v1 as tfv1
|
import tensorflow.compat.v1 as tfv1
|
||||||
from util.sparse_image_warp import sparse_image_warp
|
|
||||||
|
from .sparse_image_warp import sparse_image_warp
|
||||||
|
|
||||||
def augment_freq_time_mask(spectrogram,
|
def augment_freq_time_mask(spectrogram,
|
||||||
frequency_masking_para=30,
|
frequency_masking_para=30,
|
|
@ -0,0 +1,167 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
from __future__ import print_function, absolute_import, division
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import errno
|
||||||
|
import gzip
|
||||||
|
import os
|
||||||
|
import platform
|
||||||
|
import six.moves.urllib as urllib
|
||||||
|
import stat
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from pkg_resources import parse_version
|
||||||
|
|
||||||
|
|
||||||
|
# URL templates for TaskCluster index artifacts, keyed by scheme name. Each
# template is %-formatted with the keys 'branch_name', 'arch_string' and
# 'artifact_name'.
DEFAULT_SCHEMES = {
    'deepspeech': 'https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.deepspeech.native_client.%(branch_name)s.%(arch_string)s/artifacts/public/%(artifact_name)s',
    'tensorflow': 'https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.tensorflow.pip.%(branch_name)s.%(arch_string)s/artifacts/public/%(artifact_name)s'
}

# Active URL template. The TASKCLUSTER_SCHEME environment variable overrides
# the default 'deepspeech' template.
TASKCLUSTER_SCHEME = os.getenv('TASKCLUSTER_SCHEME', DEFAULT_SCHEMES['deepspeech'])
|
||||||
|
|
||||||
|
def get_tc_url(arch_string, artifact_name='native_client.tar.xz', branch_name='master'):
    """Return the TaskCluster URL of an artifact.

    The module-level ``TASKCLUSTER_SCHEME`` template is filled in with the
    given architecture string, artifact name and branch name.

    Raises:
        AssertionError: if ``arch_string`` is None, or if ``artifact_name``
            or ``branch_name`` is None or empty.
    """
    assert arch_string is not None
    # None is falsy as well, so one truthiness check rejects both None and ''.
    assert artifact_name
    assert branch_name

    substitutions = dict(
        arch_string=arch_string,
        artifact_name=artifact_name,
        branch_name=branch_name,
    )
    return TASKCLUSTER_SCHEME % substitutions
|
||||||
|
|
||||||
|
def maybe_download_tc(target_dir, tc_url, progress=True):
    """Download ``tc_url`` into ``target_dir`` unless the file is already there.

    The local file name is the last path component of ``tc_url``. When the
    response carries ``Content-Encoding: gzip`` the downloaded file is
    decompressed in place.

    Args:
        target_dir: Directory to download into; created if missing.
        tc_url: URL of the artifact to fetch.
        progress: When true, print a download progress indicator to stdout.

    Returns:
        The absolute path of the (possibly pre-existing) local file.

    Raises:
        OSError: if ``target_dir`` cannot be created.
        AssertionError: if ``target_dir`` is None.
    """
    def report_progress(count, block_size, total_size):
        if total_size <= 0:
            # urlretrieve reports a non-positive total when the server does
            # not announce a size; a percentage is meaningless (and a zero
            # total would divide by zero), so show the byte count instead.
            sys.stdout.write("\rDownloading: %d bytes" % (count * block_size))
            sys.stdout.flush()
            return

        percent = (count * block_size * 100) // total_size
        sys.stdout.write("\rDownloading: %d%%" % percent)
        sys.stdout.flush()

        if percent >= 100:
            print('\n')

    assert target_dir is not None

    target_dir = os.path.abspath(target_dir)
    try:
        os.makedirs(target_dir)
    except OSError as e:
        # An already existing directory is fine; anything else is a real error.
        if e.errno != errno.EEXIST:
            raise
    assert os.path.isdir(os.path.dirname(target_dir))

    tc_filename = os.path.basename(tc_url)
    target_file = os.path.join(target_dir, tc_filename)
    is_gzip = False
    if not os.path.isfile(target_file):
        print('Downloading %s ...' % tc_url)
        try:
            _, headers = urllib.request.urlretrieve(tc_url, target_file, reporthook=(report_progress if progress else None))
        except BaseException:
            # Do not leave a truncated file behind: the isfile() check above
            # would mistake it for a finished download on the next run.
            if os.path.isfile(target_file):
                os.remove(target_file)
            raise
        is_gzip = headers.get('Content-Encoding') == 'gzip'
    else:
        print('File already exists: %s' % target_file)

    if is_gzip:
        # Replace the compressed payload with its decompressed content.
        with open(target_file, "r+b") as frw:
            decompressed = gzip.decompress(frw.read())
            frw.seek(0)
            frw.write(decompressed)
            frw.truncate()

    return target_file
|
||||||
|
|
||||||
|
def maybe_download_tc_bin(**kwargs):
    """Download an artifact with ``maybe_download_tc`` and mark it executable.

    Keyword Args:
        target_dir: Directory to download into.
        tc_url: URL of the artifact to fetch.
        progress: Optional; whether to print download progress. Defaults to
            True, matching ``maybe_download_tc``.

    Returns:
        The path of the downloaded file.
    """
    final_file = maybe_download_tc(kwargs['target_dir'], kwargs['tc_url'], kwargs.get('progress', True))
    final_stat = os.stat(final_file)
    # Add the owner-execute bit on top of the existing permission bits.
    os.chmod(final_file, final_stat.st_mode | stat.S_IEXEC)
    return final_file
|
||||||
|
|
||||||
|
def read(fname):
    """Return the text content of ``fname``, resolved relative to this module's directory."""
    # Use a context manager so the file handle is closed deterministically.
    with open(os.path.join(os.path.dirname(__file__), fname)) as fin:
        return fin.read()
|
||||||
|
|
||||||
|
def main():
    """Command-line tool for fetching build artifacts from TaskCluster.

    With ``--decoder`` the URL of the matching ``ds_ctcdecoder`` wheel is
    printed and the process exits with status 0. Otherwise the requested
    artifact is downloaded into ``--target``; the
    ``convert_graphdef_memmapped_format`` artifact is made executable and
    ``.tar.`` archives are unpacked in place with ``tar``.

    Exits with status 1 when neither ``--target`` nor ``--decoder`` is given,
    or when ``--source`` names an unknown scheme.
    """
    global TASKCLUSTER_SCHEME

    parser = argparse.ArgumentParser(description='Tooling to ease downloading of components from TaskCluster.')
    parser.add_argument('--target', required=False,
                        help='Where to put the native client binary files')
    parser.add_argument('--arch', required=False,
                        help='Which architecture to download binaries for. "arm" for ARM 7 (32-bit), "arm64" for ARM64, "gpu" for CUDA enabled x86_64 binaries, "cpu" for CPU-only x86_64 binaries, "osx" for CPU-only x86_64 OSX binaries. Optional ("cpu" by default)')
    parser.add_argument('--artifact', required=False,
                        default='native_client.tar.xz',
                        help='Name of the artifact to download. Defaults to "native_client.tar.xz"')
    parser.add_argument('--source', required=False, default=None,
                        help='Name of the TaskCluster scheme to use.')
    parser.add_argument('--branch', required=False,
                        help='Branch name to use. Defaulting to current content of VERSION file.')
    parser.add_argument('--decoder', action='store_true',
                        help='Get URL to ds_ctcdecoder Python package.')

    args = parser.parse_args()

    if not (args.target or args.decoder):
        print('Pass either --target or --decoder.')
        sys.exit(1)

    # Host characteristics used to pick defaults below.
    on_arm = 'arm' in platform.machine()
    on_mac = 'darwin' in sys.platform
    wide_pointers = sys.maxsize > (2**31 - 1)
    narrow_unicode = sys.maxunicode < 0x10ffff

    if not args.arch:
        if on_arm:
            args.arch = 'arm64' if wide_pointers else 'arm'
        elif on_mac:
            args.arch = 'osx'
        else:
            args.arch = 'cpu'

    if args.branch:
        ds_version = parse_version(args.branch)
    else:
        # No branch given: derive it from the VERSION file next to the package.
        version_string = read('../VERSION').strip()
        ds_version = parse_version(version_string)
        args.branch = "v{}".format(version_string)

    if args.decoder:
        wheel_platform = platform.system().lower()
        machine = platform.machine()

        if wheel_platform == 'linux' and machine == 'x86_64':
            wheel_platform = 'manylinux1'
        elif wheel_platform == 'darwin':
            wheel_platform = 'macosx_10_10'

        abi_suffix = 'mu' if narrow_unicode else 'm'
        py_tag = '{}{}'.format(*sys.version_info[0:2])

        wheel_name = "ds_ctcdecoder-{ds_version}-cp{pyver}-cp{pyver}{m_or_mu}-{platform}_{arch}.whl".format(
            ds_version=ds_version,
            pyver=py_tag,
            m_or_mu=abi_suffix,
            platform=wheel_platform,
            arch=machine
        )

        print(get_tc_url(args.arch + '-ctc', wheel_name, args.branch))
        sys.exit(0)

    if args.source is not None:
        if args.source not in DEFAULT_SCHEMES:
            print('No such scheme: %s' % args.source)
            sys.exit(1)
        TASKCLUSTER_SCHEME = DEFAULT_SCHEMES[args.source]

    maybe_download_tc(target_dir=args.target, tc_url=get_tc_url(args.arch, args.artifact, args.branch))

    local_artifact = os.path.join(args.target, args.artifact)

    if args.artifact == "convert_graphdef_memmapped_format":
        # This artifact is a tool that gets run directly, so make it executable.
        current_mode = os.stat(local_artifact).st_mode
        os.chmod(local_artifact, current_mode | stat.S_IEXEC)

    if '.tar.' in args.artifact:
        subprocess.check_call(['tar', 'xvf', local_artifact, '-C', args.target])
|
||||||
|
|
||||||
|
# Run the command-line tool only when executed as a script, not on import.
if __name__ == '__main__':
    main()
|
|
@ -12,13 +12,13 @@ tflogging.set_verbosity(tflogging.ERROR)
|
||||||
import logging
|
import logging
|
||||||
logging.getLogger('sox').setLevel(logging.ERROR)
|
logging.getLogger('sox').setLevel(logging.ERROR)
|
||||||
|
|
||||||
from multiprocessing import Process, cpu_count
|
from deepspeech_training.util.audio import AudioFile
|
||||||
|
from deepspeech_training.util.config import Config, initialize_globals
|
||||||
|
from deepspeech_training.util.feeding import split_audio_file
|
||||||
|
from deepspeech_training.util.flags import create_flags, FLAGS
|
||||||
|
from deepspeech_training.util.logging import log_error, log_info, log_progress, create_progressbar
|
||||||
from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
|
from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
|
||||||
from util.config import Config, initialize_globals
|
from multiprocessing import Process, cpu_count
|
||||||
from util.audio import AudioFile
|
|
||||||
from util.feeding import split_audio_file
|
|
||||||
from util.flags import create_flags, FLAGS
|
|
||||||
from util.logging import log_error, log_info, log_progress, create_progressbar
|
|
||||||
|
|
||||||
|
|
||||||
def fail(message, code=1):
|
def fail(message, code=1):
|
||||||
|
@ -27,8 +27,8 @@ def fail(message, code=1):
|
||||||
|
|
||||||
|
|
||||||
def transcribe_file(audio_path, tlog_path):
|
def transcribe_file(audio_path, tlog_path):
|
||||||
from DeepSpeech import create_model # pylint: disable=cyclic-import,import-outside-toplevel
|
from deepspeech_training.train import create_model # pylint: disable=cyclic-import,import-outside-toplevel
|
||||||
from util.checkpoints import load_or_init_graph
|
from deepspeech_training.util.checkpoints import load_or_init_graph
|
||||||
initialize_globals()
|
initialize_globals()
|
||||||
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet)
|
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet)
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -1,168 +1,12 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
from __future__ import print_function, absolute_import, division
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
import argparse
|
|
||||||
import platform
|
|
||||||
import subprocess
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
import errno
|
|
||||||
import stat
|
|
||||||
import gzip
|
|
||||||
|
|
||||||
import six.moves.urllib as urllib
|
|
||||||
|
|
||||||
from pkg_resources import parse_version
|
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_SCHEMES = {
|
|
||||||
'deepspeech': 'https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.deepspeech.native_client.%(branch_name)s.%(arch_string)s/artifacts/public/%(artifact_name)s',
|
|
||||||
'tensorflow': 'https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.tensorflow.pip.%(branch_name)s.%(arch_string)s/artifacts/public/%(artifact_name)s'
|
|
||||||
}
|
|
||||||
|
|
||||||
TASKCLUSTER_SCHEME = os.getenv('TASKCLUSTER_SCHEME', DEFAULT_SCHEMES['deepspeech'])
|
|
||||||
|
|
||||||
def get_tc_url(arch_string, artifact_name='native_client.tar.xz', branch_name='master'):
|
|
||||||
assert arch_string is not None
|
|
||||||
assert artifact_name is not None
|
|
||||||
assert artifact_name
|
|
||||||
assert branch_name is not None
|
|
||||||
assert branch_name
|
|
||||||
|
|
||||||
return TASKCLUSTER_SCHEME % { 'arch_string': arch_string, 'artifact_name': artifact_name, 'branch_name': branch_name}
|
|
||||||
|
|
||||||
def maybe_download_tc(target_dir, tc_url, progress=True):
|
|
||||||
def report_progress(count, block_size, total_size):
|
|
||||||
percent = (count * block_size * 100) // total_size
|
|
||||||
sys.stdout.write("\rDownloading: %d%%" % percent)
|
|
||||||
sys.stdout.flush()
|
|
||||||
|
|
||||||
if percent >= 100:
|
|
||||||
print('\n')
|
|
||||||
|
|
||||||
assert target_dir is not None
|
|
||||||
|
|
||||||
target_dir = os.path.abspath(target_dir)
|
|
||||||
try:
|
|
||||||
os.makedirs(target_dir)
|
|
||||||
except OSError as e:
|
|
||||||
if e.errno != errno.EEXIST:
|
|
||||||
raise e
|
|
||||||
assert os.path.isdir(os.path.dirname(target_dir))
|
|
||||||
|
|
||||||
tc_filename = os.path.basename(tc_url)
|
|
||||||
target_file = os.path.join(target_dir, tc_filename)
|
|
||||||
is_gzip = False
|
|
||||||
if not os.path.isfile(target_file):
|
|
||||||
print('Downloading %s ...' % tc_url)
|
|
||||||
_, headers = urllib.request.urlretrieve(tc_url, target_file, reporthook=(report_progress if progress else None))
|
|
||||||
is_gzip = headers.get('Content-Encoding') == 'gzip'
|
|
||||||
else:
|
|
||||||
print('File already exists: %s' % target_file)
|
|
||||||
|
|
||||||
if is_gzip:
|
|
||||||
with open(target_file, "r+b") as frw:
|
|
||||||
decompressed = gzip.decompress(frw.read())
|
|
||||||
frw.seek(0)
|
|
||||||
frw.write(decompressed)
|
|
||||||
frw.truncate()
|
|
||||||
|
|
||||||
return target_file
|
|
||||||
|
|
||||||
def maybe_download_tc_bin(**kwargs):
|
|
||||||
final_file = maybe_download_tc(kwargs['target_dir'], kwargs['tc_url'], kwargs['progress'])
|
|
||||||
final_stat = os.stat(final_file)
|
|
||||||
os.chmod(final_file, final_stat.st_mode | stat.S_IEXEC)
|
|
||||||
|
|
||||||
def read(fname):
|
|
||||||
return open(os.path.join(os.path.dirname(__file__), fname)).read()
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser(description='Tooling to ease downloading of components from TaskCluster.')
|
|
||||||
parser.add_argument('--target', required=False,
|
|
||||||
help='Where to put the native client binary files')
|
|
||||||
parser.add_argument('--arch', required=False,
|
|
||||||
help='Which architecture to download binaries for. "arm" for ARM 7 (32-bit), "arm64" for ARM64, "gpu" for CUDA enabled x86_64 binaries, "cpu" for CPU-only x86_64 binaries, "osx" for CPU-only x86_64 OSX binaries. Optional ("cpu" by default)')
|
|
||||||
parser.add_argument('--artifact', required=False,
|
|
||||||
default='native_client.tar.xz',
|
|
||||||
help='Name of the artifact to download. Defaults to "native_client.tar.xz"')
|
|
||||||
parser.add_argument('--source', required=False, default=None,
|
|
||||||
help='Name of the TaskCluster scheme to use.')
|
|
||||||
parser.add_argument('--branch', required=False,
|
|
||||||
help='Branch name to use. Defaulting to current content of VERSION file.')
|
|
||||||
parser.add_argument('--decoder', action='store_true',
|
|
||||||
help='Get URL to ds_ctcdecoder Python package.')
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
if not args.target and not args.decoder:
|
|
||||||
print('Pass either --target or --decoder.')
|
|
||||||
exit(1)
|
|
||||||
|
|
||||||
is_arm = 'arm' in platform.machine()
|
|
||||||
is_mac = 'darwin' in sys.platform
|
|
||||||
is_64bit = sys.maxsize > (2**31 - 1)
|
|
||||||
is_ucs2 = sys.maxunicode < 0x10ffff
|
|
||||||
|
|
||||||
if not args.arch:
|
|
||||||
if is_arm:
|
|
||||||
args.arch = 'arm64' if is_64bit else 'arm'
|
|
||||||
elif is_mac:
|
|
||||||
args.arch = 'osx'
|
|
||||||
else:
|
|
||||||
args.arch = 'cpu'
|
|
||||||
|
|
||||||
if not args.branch:
|
|
||||||
version_string = read('../VERSION').strip()
|
|
||||||
ds_version = parse_version(version_string)
|
|
||||||
args.branch = "v{}".format(version_string)
|
|
||||||
else:
|
|
||||||
ds_version = parse_version(args.branch)
|
|
||||||
|
|
||||||
if args.decoder:
|
|
||||||
plat = platform.system().lower()
|
|
||||||
arch = platform.machine()
|
|
||||||
|
|
||||||
if plat == 'linux' and arch == 'x86_64':
|
|
||||||
plat = 'manylinux1'
|
|
||||||
|
|
||||||
if plat == 'darwin':
|
|
||||||
plat = 'macosx_10_10'
|
|
||||||
|
|
||||||
m_or_mu = 'mu' if is_ucs2 else 'm'
|
|
||||||
pyver = ''.join(map(str, sys.version_info[0:2]))
|
|
||||||
|
|
||||||
artifact = "ds_ctcdecoder-{ds_version}-cp{pyver}-cp{pyver}{m_or_mu}-{platform}_{arch}.whl".format(
|
|
||||||
ds_version=ds_version,
|
|
||||||
pyver=pyver,
|
|
||||||
m_or_mu=m_or_mu,
|
|
||||||
platform=plat,
|
|
||||||
arch=arch
|
|
||||||
)
|
|
||||||
|
|
||||||
ctc_arch = args.arch + '-ctc'
|
|
||||||
|
|
||||||
print(get_tc_url(ctc_arch, artifact, args.branch))
|
|
||||||
exit(0)
|
|
||||||
|
|
||||||
if args.source is not None:
|
|
||||||
if args.source in DEFAULT_SCHEMES:
|
|
||||||
global TASKCLUSTER_SCHEME
|
|
||||||
TASKCLUSTER_SCHEME = DEFAULT_SCHEMES[args.source]
|
|
||||||
else:
|
|
||||||
print('No such scheme: %s' % args.source)
|
|
||||||
exit(1)
|
|
||||||
|
|
||||||
maybe_download_tc(target_dir=args.target, tc_url=get_tc_url(args.arch, args.artifact, args.branch))
|
|
||||||
|
|
||||||
if args.artifact == "convert_graphdef_memmapped_format":
|
|
||||||
convert_graph_file = os.path.join(args.target, args.artifact)
|
|
||||||
final_stat = os.stat(convert_graph_file)
|
|
||||||
os.chmod(convert_graph_file, final_stat.st_mode | stat.S_IEXEC)
|
|
||||||
|
|
||||||
if '.tar.' in args.artifact:
|
|
||||||
subprocess.check_call(['tar', 'xvf', os.path.join(args.target, args.artifact), '-C', args.target])
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
try:
|
||||||
|
from deepspeech_training.util import taskcluster as dsu_taskcluster
|
||||||
|
except ImportError:
|
||||||
|
print('Training package is not installed. See training documentation.')
|
||||||
|
raise
|
||||||
|
|
||||||
|
dsu_taskcluster.main()
|
||||||
|
|
Загрузка…
Ссылка в новой задаче