diff --git a/align/align.py b/align/align.py
index f27f31c..1b4da78 100644
--- a/align/align.py
+++ b/align/align.py
@@ -1,8 +1,10 @@
 import sys
 import os
 import text
+import json
 import logging
 import argparse
+import os.path as path
 import numpy as np
 import wavTranscriber

@@ -10,9 +12,11 @@ import wavTranscriber
 def main(args):
     parser = argparse.ArgumentParser(description='Force align speech data with a transcript.')
     parser.add_argument('audio', type=str,
-                        help='Path to speech audio (WAV format)')
+                        help='Source path of speech audio (WAV format)')
     parser.add_argument('transcript', type=str,
-                        help='Path to original transcript')
+                        help='Source path of original transcript (plain text)')
+    parser.add_argument('result', type=str,
+                        help='Target path of alignment result file (JSON)')
     parser.add_argument('--aggressive', type=int, choices=range(4), required=False,
                         help='Determines how aggressive filtering out non-speech is. (Interger between 0-3)')
     parser.add_argument('--model', required=False,
@@ -23,43 +27,70 @@ def main(args):

     # Debug helpers
     logging.basicConfig(stream=sys.stderr, level=args.loglevel if args.loglevel else 20)
+    logging.debug("Start")

-    # Loading model
+    fragments = []
+    fragments_cache_path = args.result + '.cache'
     model_dir = os.path.expanduser(args.model if args.model else 'models/en')
-    output_graph, alphabet, lm, trie = wavTranscriber.resolve_models(model_dir)
-    model, _, _ = wavTranscriber.load_model(output_graph, alphabet, lm, trie)
-    alphabet = text.Alphabet(alphabet)
+    logging.debug("Looking for model files in %s..." % model_dir)
+    output_graph_path, alphabet_path, lm_path, trie_path = wavTranscriber.resolve_models(model_dir)
+    logging.debug("Loading alphabet from %s..." % alphabet_path)
+    alphabet = text.Alphabet(alphabet_path)

-    inference_time = 0.0
+    if path.exists(fragments_cache_path):
+        logging.debug("Loading cached segment transcripts from %s..." % fragments_cache_path)
+        with open(fragments_cache_path, 'r') as result_file:
+            fragments = json.loads(result_file.read())
+    else:
+        logging.debug("Loading model from %s..." % model_dir)
+        model, _, _ = wavTranscriber.load_model(output_graph_path, alphabet_path, lm_path, trie_path)

-    # Run VAD on the input file
-    wave_file = args.audio
-    aggressiveness = int(args.aggressive) if args.aggressive else 3
-    segments, sample_rate, audio_length = wavTranscriber.vad_segment_generator(wave_file, aggressiveness)
+        inference_time = 0.0
+        offset = 0

+        # Run VAD on the input file
+        logging.debug("Transcribing VAD segments...")
+        wave_file = args.audio
+        aggressiveness = int(args.aggressive) if args.aggressive else 3
+        segments, sample_rate, audio_length = wavTranscriber.vad_segment_generator(wave_file, aggressiveness)
+
+        for i, segment in enumerate(segments):
+            # Run DeepSpeech on the chunk that just completed VAD
+            segment_buffer, time_start, time_end = segment
+            logging.debug("Transcribing segment %002d (from %f to %f)..." % (i, time_start / 1000.0, time_end / 1000.0))
+
+            audio = np.frombuffer(segment_buffer, dtype=np.int16)
+            segment_transcript, segment_inference_time = wavTranscriber.stt(model, audio, sample_rate)
+            if segment_transcript is None:
+                logging.debug("Segment %002d empty" % i)
+                continue
+            inference_time += segment_inference_time
+            fragments.append({
+                'time_start': time_start,
+                'time_end': time_end,
+                'transcript': segment_transcript,
+                'offset': offset
+            })
+            offset += len(segment_transcript)
+
+        logging.debug("Writing segment transcripts to cache file %s..." % fragments_cache_path)
+        with open(fragments_cache_path, 'w') as result_file:
+            result_file.write(json.dumps(fragments))
+
+    logging.debug("Loading original transcript from %s..." % args.transcript)
     with open(args.transcript, 'r') as transcript_file:
         original_transcript = transcript_file.read()
     original_transcript = ' '.join(original_transcript.lower().split())
     original_transcript = alphabet.filter(original_transcript)
-
-    position = 0
-
-    for i, segment in enumerate(segments):
-        # Run DeepSpeech on the chunk that just completed VAD
-        logging.debug("Transcribing segment %002d..." % i)
-        audio = np.frombuffer(segment, dtype=np.int16)
-        segment_transcript, segment_inference_time = wavTranscriber.stt(model, audio, sample_rate)
-        inference_time += segment_inference_time
-
-        logging.debug("Looking for segment transcript in original transcript...")
-        distance, found_offset, found_len = \
-            text.minimal_distance(original_transcript,
-                                  segment_transcript,
-                                  start=position,
-                                  threshold=0.1)
-        logging.info("Segment transcript: %s" % segment_transcript)
-        logging.info("Segment found: %s" % original_transcript[found_offset:found_offset+found_len])
-        logging.info("--")
+    ls = text.LevenshteinSearch(original_transcript)
+    start = 0
+    for fragment in fragments:
+        logging.debug('STT Transcribed: %s' % fragment['transcript'])
+        match_distance, match_offset, match_len = ls.find_best(fragment['transcript'])
+        if match_offset >= 0:
+            fragment['original'] = original_transcript[match_offset:match_offset+match_len]
+            logging.debug('    Original: %s' % fragment['original'])
+            start = match_offset+match_len


 if __name__ == '__main__':
diff --git a/align/text.py b/align/text.py
index 1049bcb..affb1b5 100644
--- a/align/text.py
+++ b/align/text.py
@@ -2,6 +2,7 @@ from __future__ import absolute_import, division, print_function

 import codecs
 import logging
+from nltk import ngrams
 from six.moves import range


@@ -55,6 +56,57 @@ class Alphabet(object):
         return self._config_file


+class TextCleaner:
+    def __init__(self, original_text, alphabet, to_lower=True, normalize_space=True):
+        self.original_text = original_text
+        clean_text = original_text
+        if to_lower:
+            clean_text = clean_text.lower()
+        if normalize_space:
+            clean_text = ' '.join(clean_text.split())
+        self.clean_text = alphabet.filter(clean_text)
+
+    def get_original_offset(self, clean_offset):
+        return clean_offset
+
+
+class LevenshteinSearch:
+    def __init__(self, text):
+        self.text = text
+        self.ngrams = {}
+        for i, ngram in enumerate(ngrams(' ' + text + ' ', 3)):
+            if ngram in self.ngrams:
+                ngram_bucket = self.ngrams[ngram]
+            else:
+                ngram_bucket = self.ngrams[ngram] = []
+            ngram_bucket.append(i)
+
+    def find_best(self, look_for, start=0, stop=-1, threshold=0):
+        stop = len(self.text) if stop < 0 else stop
+        window_size = len(look_for)
+        windows = {}
+        for i, ngram in enumerate(ngrams(' ' + look_for + ' ', 3)):
+            if ngram in self.ngrams:
+                ngram_bucket = self.ngrams[ngram]
+                for occurrence in ngram_bucket:
+                    if occurrence < start or occurrence > stop:
+                        continue
+                    window = occurrence // window_size
+                    windows[window] = (windows[window] + 1) if window in windows else 1
+        candidate_windows = sorted(windows.keys(), key=lambda w: windows[w], reverse=True)
+        found_best = False
+        best_distance = -1
+        best_offset = -1
+        best_len = -1
+        for window in candidate_windows[0:4]:
+            for offset in range(int((window-0.5)*window_size), int((window+0.5)*window_size)):
+                distance = levenshtein(self.text[offset:offset + len(look_for)], look_for)
+                if not found_best or distance < best_distance:
+                    found_best = True
+                    best_distance = distance
+                    best_offset = offset
+                    best_len = len(look_for)
+        return best_distance, best_offset, best_len


 # The following code is from: http://hetland.org/coding/python/levenshtein.py
@@ -88,25 +140,3 @@ def levenshtein(a, b):
             current[j] = min(add, delete, change)

     return current[n]
-
-
-def minimal_distance(search_in, search_for, start=0, threshold=0):
-    best_distance = 1000000000
-    best_offset = -1
-    best_len = -1
-    window = 10
-    rough_acceptable_distance = int(1.5 * window)
-    acceptable_distance = int(len(search_for) * threshold)
-    stop = len(search_in)-len(search_for)
-    for rough_offset in range(start, stop, window):
-        rough_distance = levenshtein(search_in[rough_offset:rough_offset+len(search_for)], search_for)
-        if rough_distance < rough_acceptable_distance:
-            for offset in range(rough_offset-window, rough_offset+window, 1):
-                distance = levenshtein(search_in[offset:offset+len(search_for)], search_for)
-                if distance < best_distance:
-                    best_distance = distance
-                    best_offset = offset
-                    best_len = len(search_for)
-    if best_distance <= acceptable_distance:
-        return best_distance, best_offset, best_len
-    return -1, 0, 0
diff --git a/align/wavSplit.py b/align/wavSplit.py
index fbbbf08..4ae0bf9 100644
--- a/align/wavSplit.py
+++ b/align/wavSplit.py
@@ -94,7 +94,7 @@ def vad_collector(sample_rate, frame_duration_ms,
     triggered = False

     voiced_frames = []
-    for frame in frames:
+    for frame_index, frame in enumerate(frames):
         is_speech = vad.is_speech(frame.bytes, sample_rate)

         if not triggered:
@@ -122,13 +122,18 @@ def vad_collector(sample_rate, frame_duration_ms,
                 # audio we've collected.
                 if num_unvoiced > threshold * ring_buffer.maxlen:
                     triggered = False
-                    yield b''.join([f.bytes for f in voiced_frames])
+                    yield b''.join([f.bytes for f in voiced_frames]), \
+                          frame_duration_ms * (frame_index - len(voiced_frames)), \
+                          frame_duration_ms * frame_index
                     ring_buffer.clear()
                     voiced_frames = []
+
     if triggered:
         pass
     # If we have any leftover voiced audio when we run out of input,
     # yield it.
     if voiced_frames:
-        yield b''.join([f.bytes for f in voiced_frames])
+        yield b''.join([f.bytes for f in voiced_frames]), \
+              frame_duration_ms * (frame_index - len(voiced_frames)), \
+              frame_duration_ms * (frame_index + 1)
diff --git a/requirements.txt b/requirements.txt
index d62ee81..0614599 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,4 @@
+nltk
 six
 deepspeech
 webrtcvad
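
Note: for reference, a minimal usage sketch of the LevenshteinSearch class added in align/text.py above. It assumes the snippet is run from inside the align/ directory (so the text module is importable), and the transcript and query strings are made-up examples, not part of this change:

    from text import LevenshteinSearch

    original = 'the quick brown fox jumps over the lazy dog'
    search = LevenshteinSearch(original)
    # find_best returns (distance, offset, length); offset is -1 if nothing matched,
    # which is why align.py guards the result with "if match_offset >= 0"
    distance, offset, length = search.find_best('quick brown fox')
    if offset >= 0:
        print(original[offset:offset + length])  # prints: quick brown fox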