From f5d5e7a0dce111115ee2c8361993ad50850b1f5e Mon Sep 17 00:00:00 2001
From: Tilman Kamp <5991088+tilmankamp@users.noreply.github.com>
Date: Wed, 26 Jun 2019 19:04:37 +0200
Subject: [PATCH] WIP

---
 align/align.py          |  32 +++++-------
 align/text.py           | 107 ++++++++++++++++++++++++++++++++++++++++
 align/wavSplit.py       |   6 +--
 align/wavTranscriber.py |  96 +++++++++++++++++------------------
 4 files changed, 165 insertions(+), 76 deletions(-)
 create mode 100644 align/text.py

diff --git a/align/align.py b/align/align.py
index f6e7dea..453f6dc 100644
--- a/align/align.py
+++ b/align/align.py
@@ -5,9 +5,6 @@ import argparse
 import numpy as np
 import wavTranscriber
 
-# Debug helpers
-logging.basicConfig(stream=sys.stderr, level=logging.DEBUG)
-
 
 def main(args):
     parser = argparse.ArgumentParser(description='Transcribe long audio files using webRTC VAD or use the streaming interface')
@@ -17,15 +14,17 @@ def main(args):
                         help='Determines how aggressive filtering out non-speech is. (Interger between 0-3)')
     parser.add_argument('--model', required=False,
                         help='Path to directory that contains all model files (output_graph, lm, trie and alphabet)')
+    parser.add_argument('--loglevel', type=int, required=False,
+                        help='Log level (between 0 and 50) - default: 20')
     args = parser.parse_args()
 
+    # Debug helpers
+    logging.basicConfig(stream=sys.stderr, level=args.loglevel if args.loglevel else 20)
+
     # Loading model
     model_dir = os.path.expanduser(args.model if args.model else 'models/en')
     output_graph, alphabet, lm, trie = wavTranscriber.resolve_models(model_dir)
-    model = wavTranscriber.load_model(output_graph, alphabet, lm, trie)
-
-    title_names = ['Filename', 'Duration(s)', 'Inference Time(s)', 'Model Load Time(s)', 'LM Load Time(s)']
-    print("\n%-30s %-20s %-20s %-20s %s" % (title_names[0], title_names[1], title_names[2], title_names[3], title_names[4]))
+    model, _, _ = wavTranscriber.load_model(output_graph, alphabet, lm, trie)
 
     inference_time = 0.0
 
@@ -37,26 +36,17 @@ def main(args):
         logging.debug("Saving Transcript @: %s" % wave_file.rstrip(".wav") + ".txt")
 
         for i, segment in enumerate(segments):
-            # Run deepspeech on the chunk that just completed VAD
+            # Run DeepSpeech on the chunk that just completed VAD
            logging.debug("Processing chunk %002d" % (i,))
            audio = np.frombuffer(segment, dtype=np.int16)
-            output = wavTranscriber.stt(model[0], audio, sample_rate)
-            inference_time += output[1]
-            logging.debug("Transcript: %s" % output[0])
+            transcript, segment_inference_time = wavTranscriber.stt(model, audio, sample_rate)
+            inference_time += segment_inference_time
+            logging.info("Transcript: %s" % transcript)
 
-            f.write(output[0] + " ")
+            f.write(transcript + " ")
 
        # Summary of the files processed
        f.close()
 
-        # Extract filename from the full file path
-        filename, ext = os.path.split(os.path.basename(wave_file))
-        logging.debug("************************************************************************************************************")
-        logging.debug("%-30s %-20s %-20s %-20s %s" % (title_names[0], title_names[1], title_names[2], title_names[3], title_names[4]))
-        logging.debug("%-30s %-20.3f %-20.3f %-20.3f %-0.3f" % (filename + ext, audio_length, inference_time, model_retval[1], model_retval[2]))
-        logging.debug("************************************************************************************************************")
-        print("%-30s %-20.3f %-20.3f %-20.3f %-0.3f" % (filename + ext, audio_length, inference_time, model_retval[1], model_retval[2]))
-
-
 if __name__ == '__main__':
     main(sys.argv[1:])
diff --git a/align/text.py b/align/text.py
new file mode 100644
index 0000000..6976d41
--- /dev/null
+++ b/align/text.py
@@ -0,0 +1,107 @@
+from __future__ import absolute_import, division, print_function
+
+import codecs
+import re
+
+import numpy as np
+
+from six.moves import range
+
+
+class Alphabet(object):
+    def __init__(self, config_file):
+        self._config_file = config_file
+        self._label_to_str = []
+        self._str_to_label = {}
+        self._size = 0
+        with codecs.open(config_file, 'r', 'utf-8') as fin:
+            for line in fin:
+                if line[0:2] == '\\#':
+                    line = '#\n'
+                elif line[0] == '#':
+                    continue
+                self._label_to_str += line[:-1]  # remove the line ending
+                self._str_to_label[line[:-1]] = self._size
+                self._size += 1
+
+    def string_from_label(self, label):
+        return self._label_to_str[label]
+
+    def label_from_string(self, string):
+        try:
+            return self._str_to_label[string]
+        except KeyError as e:
+            raise KeyError(
+                '''ERROR: Your transcripts contain characters which do not occur in data/alphabet.txt! Use util/check_characters.py to see what characters are in your {train,dev,test}.csv transcripts, and then add all these to data/alphabet.txt.'''
+            ).with_traceback(e.__traceback__)
+
+    def decode(self, labels):
+        res = ''
+        for label in labels:
+            res += self.string_from_label(label)
+        return res
+
+    def size(self):
+        return self._size
+
+    def config_file(self):
+        return self._config_file
+
+
+def text_to_char_array(original, alphabet):
+    """
+    Given a Python string ``original``, remove unsupported characters, map characters
+    to integers and return a numpy array representing the processed string.
+    """
+    return np.asarray([alphabet.label_from_string(c) for c in original])
+
+
+# Validate and normalize transcriptions. Returns a cleaned version of the label
+# or None if it's invalid.
+def validate_label(label):
+    # For now we can only handle [a-z ']
+    if re.search(r"[0-9]|[(<\[\]&*{]", label) is not None:
+        return None
+
+    label = label.replace("-", "")
+    label = label.replace("_", "")
+    label = label.replace(".", "")
+    label = label.replace(",", "")
+    label = label.replace("?", "")
+    label = label.replace("\"", "")
+    label = label.strip()
+    label = label.lower()
+
+    return label if label else None
+
+
+# The following code is from: http://hetland.org/coding/python/levenshtein.py
+
+# This is a straightforward implementation of a well-known algorithm, and thus
+# probably shouldn't be covered by copyright to begin with. But in case it is,
+# the author (Magnus Lie Hetland) has, to the extent possible under law,
+# dedicated all copyright and related and neighboring rights to this software
+# to the public domain worldwide, by distributing it under the CC0 license,
+# version 1.0. This software is distributed without any warranty. For more
+# information, see <https://creativecommons.org/publicdomain/zero/1.0/>.
+
+def levenshtein(a, b):
+    """
+    Calculates the Levenshtein distance between a and b.
+ """ + n, m = len(a), len(b) + if n > m: + # Make sure n <= m, to use O(min(n,m)) space + a, b = b, a + n, m = m, n + + current = list(range(n+1)) + for i in range(1, m+1): + previous, current = current, [i]+[0]*n + for j in range(1, n+1): + add, delete = previous[j]+1, current[j-1]+1 + change = previous[j-1] + if a[j-1] != b[i-1]: + change = change + 1 + current[j] = min(add, delete, change) + + return current[n] diff --git a/align/wavSplit.py b/align/wavSplit.py index 44aa573..fbbbf08 100644 --- a/align/wavSplit.py +++ b/align/wavSplit.py @@ -60,7 +60,7 @@ def frame_generator(frame_duration_ms, audio, sample_rate): def vad_collector(sample_rate, frame_duration_ms, - padding_duration_ms, vad, frames): + padding_duration_ms, threshold, vad, frames): """Filters out non-voiced audio frames. Given a webrtcvad.Vad and a source of audio frames, yields only @@ -103,7 +103,7 @@ def vad_collector(sample_rate, frame_duration_ms, # If we're NOTTRIGGERED and more than 90% of the frames in # the ring buffer are voiced frames, then enter the # TRIGGERED state. - if num_voiced > 0.9 * ring_buffer.maxlen: + if num_voiced > threshold * ring_buffer.maxlen: triggered = True # We want to yield all the audio we see from now until # we are NOTTRIGGERED, but we have to start with the @@ -120,7 +120,7 @@ def vad_collector(sample_rate, frame_duration_ms, # If more than 90% of the frames in the ring buffer are # unvoiced, then enter NOTTRIGGERED and yield whatever # audio we've collected. - if num_unvoiced > 0.9 * ring_buffer.maxlen: + if num_unvoiced > threshold * ring_buffer.maxlen: triggered = False yield b''.join([f.bytes for f in voiced_frames]) ring_buffer.clear() diff --git a/align/wavTranscriber.py b/align/wavTranscriber.py index 2735879..0db3024 100644 --- a/align/wavTranscriber.py +++ b/align/wavTranscriber.py @@ -5,17 +5,16 @@ import wavSplit from deepspeech import Model from timeit import default_timer as timer -''' -Load the pre-trained model into the memory -@param models: Output Grapgh Protocol Buffer file -@param alphabet: Alphabet.txt file -@param lm: Language model file -@param trie: Trie file -@Retval -Returns a list [DeepSpeech Object, Model Load Time, LM Load Time] -''' def load_model(models, alphabet, lm, trie): + """ + Load the pre-trained model into the memory + :param models: Output Graph Protocol Buffer file + :param alphabet: Alphabet.txt file + :param lm: Language model file + :param trie: Trie file + :return: tuple (DeepSpeech object, Model Load Time, LM Load Time) + """ N_FEATURES = 26 N_CONTEXT = 9 BEAM_WIDTH = 500 @@ -32,72 +31,65 @@ def load_model(models, alphabet, lm, trie): lm_load_end = timer() - lm_load_start logging.debug('Loaded language model in %0.3fs.' 
% (lm_load_end)) - return [ds, model_load_end, lm_load_end] + return ds, model_load_end, lm_load_end -''' -Run Inference on input audio file -@param ds: Deepspeech object -@param audio: Input audio for running inference on -@param fs: Sample rate of the input audio file -@Retval: -Returns a list [Inference, Inference Time, Audio Length] - -''' def stt(ds, audio, fs): - inference_time = 0.0 + """ + Run Inference on input audio file + :param ds: DeepSpeech object + :param audio: Input audio for running inference on + :param fs: Sample rate of the input audio file + :return: tuple (Inference result text, Inference time) + """ audio_length = len(audio) * (1 / 16000) - # Run Deepspeech + # Run DeepSpeech logging.debug('Running inference...') inference_start = timer() output = ds.stt(audio, fs) - inference_end = timer() - inference_start - inference_time += inference_end - logging.debug('Inference took %0.3fs for %0.3fs audio file.' % (inference_end, audio_length)) + inference_time = timer() - inference_start + logging.debug('Inference took %0.3fs for %0.3fs audio file.' % (inference_time, audio_length)) + return output, inference_time - return [output, inference_time] -''' -Resolve directory path for the models and fetch each of them. -@param dirName: Path to the directory containing pre-trained models - -@Retval: -Retunns a tuple containing each of the model files (pb, alphabet, lm and trie) -''' -def resolve_models(dirName): - pb = glob.glob(dirName + "/*.pb")[0] +def resolve_models(dir_name): + """ + Resolve directory path for the models and fetch each of them. + :param dir_name: Path to the directory containing pre-trained models + :return: tuple containing each of the model files (pb, alphabet, lm and trie) + """ + pb = glob.glob(dir_name + "/*.pb")[0] logging.debug("Found Model: %s" % pb) - alphabet = glob.glob(dirName + "/alphabet.txt")[0] + alphabet = glob.glob(dir_name + "/alphabet.txt")[0] logging.debug("Found Alphabet: %s" % alphabet) - lm = glob.glob(dirName + "/lm.binary")[0] - trie = glob.glob(dirName + "/trie")[0] + lm = glob.glob(dir_name + "/lm.binary")[0] + trie = glob.glob(dir_name + "/trie")[0] logging.debug("Found Language Model: %s" % lm) logging.debug("Found Trie: %s" % trie) return pb, alphabet, lm, trie -''' -Generate VAD segments. Filters out non-voiced audio frames. -@param waveFile: Input wav file to run VAD on.0 -@Retval: -Returns tuple of - segments: a bytearray of multiple smaller audio frames - (The longer audio split into mutiple smaller one's) - sample_rate: Sample rate of the input audio file - audio_length: Duraton of the input audio file - -''' -def vad_segment_generator(wavFile, aggressiveness): - logging.debug("Caught the wav file @: %s" % (wavFile)) - audio, sample_rate, audio_length = wavSplit.read_wave(wavFile) +def vad_segment_generator(wav_file, aggressiveness): + """ + Generate VAD segments. Filters out non-voiced audio frames. + :param wav_file: Input wav file to run VAD on.0 + :param aggressiveness: How aggressive filtering out non-speech is (between 0 and 3) + :return: Returns tuple of + segments: a bytearray of multiple smaller audio frames + (The longer audio split into multiple smaller one's) + sample_rate: Sample rate of the input audio file + audio_length: Duration of the input audio file + """ + logging.debug("Caught the wav file @: %s" % wav_file) + audio, sample_rate, audio_length = wavSplit.read_wave(wav_file) assert sample_rate == 16000, "Only 16000Hz input WAV files are supported for now!" 
     vad = webrtcvad.Vad(int(aggressiveness))
     frames = wavSplit.frame_generator(30, audio, sample_rate)
     frames = list(frames)
-    segments = wavSplit.vad_collector(sample_rate, 30, 300, vad, frames)
+    segments = wavSplit.vad_collector(sample_rate, 30, 300, 0.5, vad, frames)
 
     return segments, sample_rate, audio_length
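
Usage sketch (illustration only, not part of the diff above): how the refactored wavTranscriber API fits together after this change. load_model() now returns a (model, model_load_time, lm_load_time) tuple instead of a list, and stt() returns (transcript, inference_time). The model directory and wav path below are placeholder assumptions.

import numpy as np

import wavTranscriber

model_dir = 'models/en'         # placeholder: directory with output_graph, alphabet, lm and trie
wav_file = 'audio/example.wav'  # placeholder: 16kHz mono WAV file
aggressiveness = 1              # webrtcvad aggressiveness, 0-3

# Resolve the model files and load them once.
output_graph, alphabet, lm, trie = wavTranscriber.resolve_models(model_dir)
model, model_load_time, lm_load_time = wavTranscriber.load_model(output_graph, alphabet, lm, trie)

# Split the recording into voiced segments and transcribe each one.
segments, sample_rate, audio_length = wavTranscriber.vad_segment_generator(wav_file, aggressiveness)
for segment in segments:
    audio = np.frombuffer(segment, dtype=np.int16)
    transcript, inference_time = wavTranscriber.stt(model, audio, sample_rate)
    print(transcript)

Note that vad_segment_generator() currently hard-codes the new vad_collector() trigger threshold to 0.5; a caller that needs a different threshold would have to invoke wavSplit.vad_collector() directly.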
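Similarly, a small illustration of the helpers introduced in align/text.py (run from the align directory, as with the other modules; the alphabet path is a placeholder):

from text import Alphabet, levenshtein, text_to_char_array, validate_label

# Character-level edit distance between two transcripts.
assert levenshtein('kitten', 'sitting') == 3

# Normalize a raw transcript; digits and other unsupported characters reject the label.
print(validate_label('Hello, world?'))  # -> 'hello world'
print(validate_label('3rd of June'))    # -> None

# Map a cleaned transcript to label indices and back using the model's alphabet file.
alphabet = Alphabet('models/en/alphabet.txt')  # placeholder path
labels = text_to_char_array('hello world', alphabet)
print(alphabet.decode(labels))          # -> 'hello world'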