Tilman Kamp 2019-06-26 19:04:37 +02:00
Parent 407a6d26d1
Commit f5d5e7a0dc
4 changed files with 165 additions and 76 deletions

View file

@@ -5,9 +5,6 @@ import argparse
 import numpy as np
 import wavTranscriber
-# Debug helpers
-logging.basicConfig(stream=sys.stderr, level=logging.DEBUG)
 def main(args):
     parser = argparse.ArgumentParser(description='Transcribe long audio files using webRTC VAD or use the streaming interface')
@@ -17,15 +14,17 @@ def main(args):
                         help='Determines how aggressive filtering out non-speech is. (Integer between 0-3)')
     parser.add_argument('--model', required=False,
                         help='Path to directory that contains all model files (output_graph, lm, trie and alphabet)')
+    parser.add_argument('--loglevel', type=int, required=False,
+                        help='Log level (between 0 and 50) - default: 20')
     args = parser.parse_args()
+    # Debug helpers
+    logging.basicConfig(stream=sys.stderr, level=args.loglevel if args.loglevel else 20)
     # Loading model
     model_dir = os.path.expanduser(args.model if args.model else 'models/en')
     output_graph, alphabet, lm, trie = wavTranscriber.resolve_models(model_dir)
-    model = wavTranscriber.load_model(output_graph, alphabet, lm, trie)
-    title_names = ['Filename', 'Duration(s)', 'Inference Time(s)', 'Model Load Time(s)', 'LM Load Time(s)']
-    print("\n%-30s %-20s %-20s %-20s %s" % (title_names[0], title_names[1], title_names[2], title_names[3], title_names[4]))
+    model, _, _ = wavTranscriber.load_model(output_graph, alphabet, lm, trie)
     inference_time = 0.0
@@ -37,26 +36,17 @@ def main(args):
logging.debug("Saving Transcript @: %s" % wave_file.rstrip(".wav") + ".txt")
for i, segment in enumerate(segments):
# Run deepspeech on the chunk that just completed VAD
# Run DeepSpeech on the chunk that just completed VAD
logging.debug("Processing chunk %002d" % (i,))
audio = np.frombuffer(segment, dtype=np.int16)
output = wavTranscriber.stt(model[0], audio, sample_rate)
inference_time += output[1]
logging.debug("Transcript: %s" % output[0])
transcript, segment_inference_time = wavTranscriber.stt(model, audio, sample_rate)
inference_time += segment_inference_time
logging.info("Transcript: %s" % transcript)
f.write(output[0] + " ")
f.write(transcript + " ")
# Summary of the files processed
f.close()
# Extract filename from the full file path
filename, ext = os.path.split(os.path.basename(wave_file))
logging.debug("************************************************************************************************************")
logging.debug("%-30s %-20s %-20s %-20s %s" % (title_names[0], title_names[1], title_names[2], title_names[3], title_names[4]))
logging.debug("%-30s %-20.3f %-20.3f %-20.3f %-0.3f" % (filename + ext, audio_length, inference_time, model_retval[1], model_retval[2]))
logging.debug("************************************************************************************************************")
print("%-30s %-20.3f %-20.3f %-20.3f %-0.3f" % (filename + ext, audio_length, inference_time, model_retval[1], model_retval[2]))
if __name__ == '__main__':
main(sys.argv[1:])
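The new --loglevel flag feeds its value straight into logging.basicConfig, so it accepts the standard library's numeric levels. A quick reference sketch (the level constants come from Python's logging module, not from this diff):

    import sys
    import logging

    # Standard numeric levels: DEBUG=10, INFO=20, WARNING=30, ERROR=40, CRITICAL=50
    logging.basicConfig(stream=sys.stderr, level=logging.INFO)  # same effect as the default --loglevel 20
    logging.info("Transcript lines are visible at level 20")
    logging.debug("Per-chunk debug output needs --loglevel 10 or lower")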

align/text.py Normal file
Просмотреть файл

@@ -0,0 +1,107 @@
from __future__ import absolute_import, division, print_function
import codecs
import re
import numpy as np
from six.moves import range
class Alphabet(object):
    def __init__(self, config_file):
        self._config_file = config_file
        self._label_to_str = []
        self._str_to_label = {}
        self._size = 0
        with codecs.open(config_file, 'r', 'utf-8') as fin:
            for line in fin:
                if line[0:2] == '\\#':
                    line = '#\n'
                elif line[0] == '#':
                    continue
                self._label_to_str += line[:-1]  # remove the line ending
                self._str_to_label[line[:-1]] = self._size
                self._size += 1

    def string_from_label(self, label):
        return self._label_to_str[label]

    def label_from_string(self, string):
        try:
            return self._str_to_label[string]
        except KeyError as e:
            raise KeyError(
                '''ERROR: Your transcripts contain characters which do not occur in data/alphabet.txt! Use util/check_characters.py to see what characters are in your {train,dev,test}.csv transcripts, and then add all these to data/alphabet.txt.'''
            ).with_traceback(e.__traceback__)

    def decode(self, labels):
        res = ''
        for label in labels:
            res += self.string_from_label(label)
        return res

    def size(self):
        return self._size

    def config_file(self):
        return self._config_file
def text_to_char_array(original, alphabet):
    """
    Given a Python string ``original``, remove unsupported characters, map characters
    to integers and return a numpy array representing the processed string.
    """
    return np.asarray([alphabet.label_from_string(c) for c in original])
# Validate and normalize transcriptions. Returns a cleaned version of the label
# or None if it's invalid.
def validate_label(label):
    # For now we can only handle [a-z ']
    if re.search(r"[0-9]|[(<\[\]&*{]", label) is not None:
        return None

    label = label.replace("-", "")
    label = label.replace("_", "")
    label = label.replace(".", "")
    label = label.replace(",", "")
    label = label.replace("?", "")
    label = label.replace("\"", "")
    label = label.strip()
    label = label.lower()

    return label if label else None
# The following code is from: http://hetland.org/coding/python/levenshtein.py
# This is a straightforward implementation of a well-known algorithm, and thus
# probably shouldn't be covered by copyright to begin with. But in case it is,
# the author (Magnus Lie Hetland) has, to the extent possible under law,
# dedicated all copyright and related and neighboring rights to this software
# to the public domain worldwide, by distributing it under the CC0 license,
# version 1.0. This software is distributed without any warranty. For more
# information, see <http://creativecommons.org/publicdomain/zero/1.0>
def levenshtein(a, b):
    """
    Calculates the Levenshtein distance between a and b.
    """
    n, m = len(a), len(b)
    if n > m:
        # Make sure n <= m, to use O(min(n,m)) space
        a, b = b, a
        n, m = m, n

    current = list(range(n+1))
    for i in range(1, m+1):
        previous, current = current, [i]+[0]*n
        for j in range(1, n+1):
            add, delete = previous[j]+1, current[j-1]+1
            change = previous[j-1]
            if a[j-1] != b[i-1]:
                change = change + 1
            current[j] = min(add, delete, change)

    return current[n]
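Taken together, the new align/text.py covers the text side of alignment: mapping transcripts to label ids, normalizing raw labels, and measuring edit distance. A minimal usage sketch, assuming it is run from the align directory and that models/en/alphabet.txt exists (both are assumptions, not shown in this diff):

    from text import Alphabet, text_to_char_array, validate_label, levenshtein

    alphabet = Alphabet('models/en/alphabet.txt')  # hypothetical path

    label = validate_label("Hello, World?")        # -> "hello world"
    if label is not None:
        label_ids = text_to_char_array(label, alphabet)  # numpy array of character ids
        assert alphabet.decode(label_ids) == label       # decode() inverts the mapping

    # Edit distance between a reference and a hypothesis transcript:
    assert levenshtein("kitten", "sitting") == 3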

View file

@@ -60,7 +60,7 @@ def frame_generator(frame_duration_ms, audio, sample_rate):
 def vad_collector(sample_rate, frame_duration_ms,
-                  padding_duration_ms, vad, frames):
+                  padding_duration_ms, threshold, vad, frames):
     """Filters out non-voiced audio frames.
     Given a webrtcvad.Vad and a source of audio frames, yields only
@@ -103,7 +103,7 @@ def vad_collector(sample_rate, frame_duration_ms,
             # If we're NOTTRIGGERED and more than 90% of the frames in
             # the ring buffer are voiced frames, then enter the
             # TRIGGERED state.
-            if num_voiced > 0.9 * ring_buffer.maxlen:
+            if num_voiced > threshold * ring_buffer.maxlen:
                 triggered = True
                 # We want to yield all the audio we see from now until
                 # we are NOTTRIGGERED, but we have to start with the
@@ -120,7 +120,7 @@ def vad_collector(sample_rate, frame_duration_ms,
             # If more than 90% of the frames in the ring buffer are
             # unvoiced, then enter NOTTRIGGERED and yield whatever
             # audio we've collected.
-            if num_unvoiced > 0.9 * ring_buffer.maxlen:
+            if num_unvoiced > threshold * ring_buffer.maxlen:
                 triggered = False
                 yield b''.join([f.bytes for f in voiced_frames])
                 ring_buffer.clear()
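The hard-coded 0.9 trigger ratio is now a threshold parameter, so callers can tune how much of the ring buffer must be voiced (or unvoiced) before a segment opens or closes. A hedged call sketch, assuming frames was produced by frame_generator as elsewhere in this file:

    import webrtcvad

    vad = webrtcvad.Vad(1)
    # With threshold=0.5, half of the 300 ms padding buffer must be voiced to open
    # a segment and half unvoiced to close it; the old behaviour was a fixed 0.9.
    segments = vad_collector(sample_rate=16000, frame_duration_ms=30,
                             padding_duration_ms=300, threshold=0.5,
                             vad=vad, frames=frames)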

View file

@@ -5,17 +5,16 @@ import wavSplit
 from deepspeech import Model
 from timeit import default_timer as timer
-'''
-Load the pre-trained model into the memory
-@param models: Output Grapgh Protocol Buffer file
-@param alphabet: Alphabet.txt file
-@param lm: Language model file
-@param trie: Trie file
-@Retval
-Returns a list [DeepSpeech Object, Model Load Time, LM Load Time]
-'''
 def load_model(models, alphabet, lm, trie):
+    """
+    Load the pre-trained model into the memory
+    :param models: Output Graph Protocol Buffer file
+    :param alphabet: Alphabet.txt file
+    :param lm: Language model file
+    :param trie: Trie file
+    :return: tuple (DeepSpeech object, Model Load Time, LM Load Time)
+    """
     N_FEATURES = 26
     N_CONTEXT = 9
     BEAM_WIDTH = 500
@@ -32,72 +31,65 @@ def load_model(models, alphabet, lm, trie):
     lm_load_end = timer() - lm_load_start
     logging.debug('Loaded language model in %0.3fs.' % (lm_load_end))
-    return [ds, model_load_end, lm_load_end]
+    return ds, model_load_end, lm_load_end
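load_model now returns a plain tuple rather than a list, so callers can unpack it directly and discard the timings they do not need, as the updated command-line script above does. A short sketch (the model directory is an assumption):

    output_graph, alphabet, lm, trie = wavTranscriber.resolve_models('models/en')
    model, model_load_time, lm_load_time = wavTranscriber.load_model(output_graph, alphabet, lm, trie)
    # or, when only the model object matters:
    model, _, _ = wavTranscriber.load_model(output_graph, alphabet, lm, trie)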
-'''
-Run Inference on input audio file
-@param ds: Deepspeech object
-@param audio: Input audio for running inference on
-@param fs: Sample rate of the input audio file
-@Retval:
-Returns a list [Inference, Inference Time, Audio Length]
-'''
 def stt(ds, audio, fs):
-    inference_time = 0.0
+    """
+    Run Inference on input audio file
+    :param ds: DeepSpeech object
+    :param audio: Input audio for running inference on
+    :param fs: Sample rate of the input audio file
+    :return: tuple (Inference result text, Inference time)
+    """
     audio_length = len(audio) * (1 / 16000)
-    # Run Deepspeech
+    # Run DeepSpeech
     logging.debug('Running inference...')
     inference_start = timer()
     output = ds.stt(audio, fs)
-    inference_end = timer() - inference_start
-    inference_time += inference_end
-    logging.debug('Inference took %0.3fs for %0.3fs audio file.' % (inference_end, audio_length))
-    return [output, inference_time]
+    inference_time = timer() - inference_start
+    logging.debug('Inference took %0.3fs for %0.3fs audio file.' % (inference_time, audio_length))
+    return output, inference_time
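stt likewise returns a (transcript, inference_time) tuple now, which reads better at the call site than indexing into a list. A hedged sketch, where segment comes from the VAD generator shown below and model from load_model above:

    import numpy as np

    audio = np.frombuffer(segment, dtype=np.int16)  # 16000 Hz is asserted upstream
    transcript, inference_time = wavTranscriber.stt(model, audio, 16000)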
-'''
-Resolve directory path for the models and fetch each of them.
-@param dirName: Path to the directory containing pre-trained models
-@Retval:
-Retunns a tuple containing each of the model files (pb, alphabet, lm and trie)
-'''
-def resolve_models(dirName):
-    pb = glob.glob(dirName + "/*.pb")[0]
+def resolve_models(dir_name):
+    """
+    Resolve directory path for the models and fetch each of them.
+    :param dir_name: Path to the directory containing pre-trained models
+    :return: tuple containing each of the model files (pb, alphabet, lm and trie)
+    """
+    pb = glob.glob(dir_name + "/*.pb")[0]
     logging.debug("Found Model: %s" % pb)
-    alphabet = glob.glob(dirName + "/alphabet.txt")[0]
+    alphabet = glob.glob(dir_name + "/alphabet.txt")[0]
     logging.debug("Found Alphabet: %s" % alphabet)
-    lm = glob.glob(dirName + "/lm.binary")[0]
-    trie = glob.glob(dirName + "/trie")[0]
+    lm = glob.glob(dir_name + "/lm.binary")[0]
+    trie = glob.glob(dir_name + "/trie")[0]
     logging.debug("Found Language Model: %s" % lm)
     logging.debug("Found Trie: %s" % trie)
     return pb, alphabet, lm, trie
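resolve_models still globs the directory and will raise IndexError if any of the four files is missing; a hedged wrapper that fails with a clearer message (illustrative only, not part of the diff):

    import os

    model_dir = os.path.expanduser('models/en')  # hypothetical path
    try:
        pb, alphabet, lm, trie = wavTranscriber.resolve_models(model_dir)
    except IndexError:
        raise SystemExit("Missing model file in %s (need *.pb, alphabet.txt, lm.binary and trie)" % model_dir)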
-'''
-Generate VAD segments. Filters out non-voiced audio frames.
-@param waveFile: Input wav file to run VAD on.0
-@Retval:
-Returns tuple of
-    segments: a bytearray of multiple smaller audio frames
-              (The longer audio split into mutiple smaller one's)
-    sample_rate: Sample rate of the input audio file
-    audio_length: Duraton of the input audio file
-'''
-def vad_segment_generator(wavFile, aggressiveness):
-    logging.debug("Caught the wav file @: %s" % (wavFile))
-    audio, sample_rate, audio_length = wavSplit.read_wave(wavFile)
+def vad_segment_generator(wav_file, aggressiveness):
+    """
+    Generate VAD segments. Filters out non-voiced audio frames.
+    :param wav_file: Input wav file to run VAD on
+    :param aggressiveness: How aggressive filtering out non-speech is (between 0 and 3)
+    :return: Returns tuple of
+        segments: a bytearray of multiple smaller audio frames
+                  (the longer audio split into multiple smaller ones)
+        sample_rate: Sample rate of the input audio file
+        audio_length: Duration of the input audio file
+    """
+    logging.debug("Caught the wav file @: %s" % wav_file)
+    audio, sample_rate, audio_length = wavSplit.read_wave(wav_file)
     assert sample_rate == 16000, "Only 16000Hz input WAV files are supported for now!"
     vad = webrtcvad.Vad(int(aggressiveness))
     frames = wavSplit.frame_generator(30, audio, sample_rate)
     frames = list(frames)
-    segments = wavSplit.vad_collector(sample_rate, 30, 300, vad, frames)
+    segments = wavSplit.vad_collector(sample_rate, 30, 300, 0.5, vad, frames)
     return segments, sample_rate, audio_length
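Putting the updated API together, a minimal end-to-end sketch mirroring the command-line script (the model directory and wav path are assumptions):

    import numpy as np
    import wavTranscriber

    output_graph, alphabet, lm, trie = wavTranscriber.resolve_models('models/en')
    model, _, _ = wavTranscriber.load_model(output_graph, alphabet, lm, trie)

    segments, sample_rate, audio_length = wavTranscriber.vad_segment_generator('speech.wav', 1)
    for i, segment in enumerate(segments):
        audio = np.frombuffer(segment, dtype=np.int16)
        transcript, inference_time = wavTranscriber.stt(model, audio, sample_rate)
        print("chunk %d: %s" % (i, transcript))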