This commit is contained in:
Tilman Kamp 2019-07-04 18:00:47 +02:00
Parent c1c222c2db
Commit 6e29fa594b
4 changed files: 122 additions and 55 deletions

View file

@@ -1,8 +1,10 @@
 import sys
 import os
 import text
+import json
 import logging
 import argparse
+import os.path as path
 import numpy as np
 import wavTranscriber
@@ -10,9 +12,11 @@ import wavTranscriber
 def main(args):
     parser = argparse.ArgumentParser(description='Force align speech data with a transcript.')
     parser.add_argument('audio', type=str,
-                        help='Path to speech audio (WAV format)')
+                        help='Source path of speech audio (WAV format)')
     parser.add_argument('transcript', type=str,
-                        help='Path to original transcript')
+                        help='Source path of original transcript (plain text)')
+    parser.add_argument('result', type=str,
+                        help='Target path of alignment result file (JSON)')
     parser.add_argument('--aggressive', type=int, choices=range(4), required=False,
                         help='Determines how aggressive filtering out non-speech is. (Integer between 0-3)')
     parser.add_argument('--model', required=False,
@@ -23,43 +27,70 @@ def main(args):
     # Debug helpers
     logging.basicConfig(stream=sys.stderr, level=args.loglevel if args.loglevel else 20)
+    logging.debug("Start")
 
-    # Loading model
+    fragments = []
+    fragments_cache_path = args.result + '.cache'
     model_dir = os.path.expanduser(args.model if args.model else 'models/en')
-    output_graph, alphabet, lm, trie = wavTranscriber.resolve_models(model_dir)
-    model, _, _ = wavTranscriber.load_model(output_graph, alphabet, lm, trie)
-    alphabet = text.Alphabet(alphabet)
+    logging.debug("Looking for model files in %s..." % model_dir)
+    output_graph_path, alphabet_path, lm_path, trie_path = wavTranscriber.resolve_models(model_dir)
+    logging.debug("Loading alphabet from %s..." % alphabet_path)
+    alphabet = text.Alphabet(alphabet_path)
 
-    inference_time = 0.0
-
-    # Run VAD on the input file
-    wave_file = args.audio
-    aggressiveness = int(args.aggressive) if args.aggressive else 3
-    segments, sample_rate, audio_length = wavTranscriber.vad_segment_generator(wave_file, aggressiveness)
+    if path.exists(fragments_cache_path):
+        logging.debug("Loading cached segment transcripts from %s..." % fragments_cache_path)
+        with open(fragments_cache_path, 'r') as result_file:
+            fragments = json.loads(result_file.read())
+    else:
+        logging.debug("Loading model from %s..." % model_dir)
+        model, _, _ = wavTranscriber.load_model(output_graph_path, alphabet_path, lm_path, trie_path)
+        inference_time = 0.0
+        offset = 0
+
+        # Run VAD on the input file
+        logging.debug("Transcribing VAD segments...")
+        wave_file = args.audio
+        aggressiveness = int(args.aggressive) if args.aggressive else 3
+        segments, sample_rate, audio_length = wavTranscriber.vad_segment_generator(wave_file, aggressiveness)
+
+        for i, segment in enumerate(segments):
+            # Run DeepSpeech on the chunk that just completed VAD
+            segment_buffer, time_start, time_end = segment
+            logging.debug("Transcribing segment %002d (from %f to %f)..." % (i, time_start / 1000.0, time_end / 1000.0))
+            audio = np.frombuffer(segment_buffer, dtype=np.int16)
+            segment_transcript, segment_inference_time = wavTranscriber.stt(model, audio, sample_rate)
+            if segment_transcript is None:
+                logging.debug("Segment %002d empty" % i)
+                continue
+            inference_time += segment_inference_time
+            fragments.append({
+                'time_start': time_start,
+                'time_end': time_end,
+                'transcript': segment_transcript,
+                'offset': offset
+            })
+            offset += len(segment_transcript)
+
+        logging.debug("Writing segment transcripts to cache file %s..." % fragments_cache_path)
+        with open(fragments_cache_path, 'w') as result_file:
+            result_file.write(json.dumps(fragments))
 
+    logging.debug("Loading original transcript from %s..." % args.transcript)
     with open(args.transcript, 'r') as transcript_file:
         original_transcript = transcript_file.read()
     original_transcript = ' '.join(original_transcript.lower().split())
     original_transcript = alphabet.filter(original_transcript)
+    ls = text.LevenshteinSearch(original_transcript)
 
-    position = 0
-
-    for i, segment in enumerate(segments):
-        # Run DeepSpeech on the chunk that just completed VAD
-        logging.debug("Transcribing segment %002d..." % i)
-        audio = np.frombuffer(segment, dtype=np.int16)
-        segment_transcript, segment_inference_time = wavTranscriber.stt(model, audio, sample_rate)
-        inference_time += segment_inference_time
-        logging.debug("Looking for segment transcript in original transcript...")
-        distance, found_offset, found_len = \
-            text.minimal_distance(original_transcript,
-                                  segment_transcript,
-                                  start=position,
-                                  threshold=0.1)
-        logging.info("Segment transcript: %s" % segment_transcript)
-        logging.info("Segment found: %s" % original_transcript[found_offset:found_offset+found_len])
-        logging.info("--")
+    start = 0
+    for fragment in fragments:
+        logging.debug('STT Transcribed: %s' % fragment['transcript'])
+        match_distance, match_offset, match_len = ls.find_best(fragment['transcript'])
+        if match_offset >= 0:
+            fragment['original'] = original_transcript[match_offset:match_offset+match_len]
+            logging.debug(' Original: %s' % fragment['original'])
+            start = match_offset+match_len
 
 if __name__ == '__main__':
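For context, a hedged sketch of how the changed script might be run and how its new cache file could be read back. The script name align.py and all file paths are assumptions, since the diff page does not show file names:

    # Assumed invocation:
    #   python align.py speech.wav transcript.txt result.json --aggressive 3
    # The cache is written next to the result file as 'result.json.cache'.
    import json

    with open('result.json.cache', 'r') as cache_file:
        fragments = json.load(cache_file)
    for fragment in fragments:
        # time_start/time_end are in milliseconds; offset indexes into the
        # concatenation of all segment transcripts
        print(fragment['time_start'], fragment['time_end'], fragment['transcript'])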

View file

@@ -2,6 +2,7 @@ from __future__ import absolute_import, division, print_function
 import codecs
 import logging
+from nltk import ngrams
 from six.moves import range
@@ -55,6 +56,57 @@ class Alphabet(object):
         return self._config_file
 
+class TextCleaner:
+    def __init__(self, original_text, alphabet, to_lower=True, normalize_space=True):
+        self.original_text = original_text
+        clean_text = original_text
+        if to_lower:
+            clean_text = clean_text.lower()
+        if normalize_space:
+            clean_text = ' '.join(clean_text.split())
+        self.clean_text = alphabet.filter(clean_text)
+
+    def get_original_offset(self, clean_offset):
+        return clean_offset
+
+class LevenshteinSearch:
+    def __init__(self, text):
+        self.text = text
+        self.ngrams = {}
+        for i, ngram in enumerate(ngrams(' ' + text + ' ', 3)):
+            if ngram in self.ngrams:
+                ngram_bucket = self.ngrams[ngram]
+            else:
+                ngram_bucket = self.ngrams[ngram] = []
+            ngram_bucket.append(i)
+
+    def find_best(self, look_for, start=0, stop=-1, threshold=0):
+        stop = len(self.text) if stop < 0 else stop
+        window_size = len(look_for)
+        windows = {}
+        for i, ngram in enumerate(ngrams(' ' + look_for + ' ', 3)):
+            if ngram in self.ngrams:
+                ngram_bucket = self.ngrams[ngram]
+                for occurrence in ngram_bucket:
+                    if occurrence < start or occurrence > stop:
+                        continue
+                    window = occurrence // window_size
+                    windows[window] = (windows[window] + 1) if window in windows else 1
+        candidate_windows = sorted(windows.keys(), key=lambda w: windows[w], reverse=True)
+        found_best = False
+        best_distance = -1
+        best_offset = -1
+        best_len = -1
+        for window in candidate_windows[0:4]:
+            for offset in range(int((window-0.5)*window_size), int((window+0.5)*window_size)):
+                distance = levenshtein(self.text[offset:offset + len(look_for)], look_for)
+                if not found_best or distance < best_distance:
+                    found_best = True
+                    best_distance = distance
+                    best_offset = offset
+                    best_len = len(look_for)
+        return best_distance, best_offset, best_len
 # The following code is from: http://hetland.org/coding/python/levenshtein.py
@@ -88,25 +140,3 @@ def levenshtein(a, b):
             current[j] = min(add, delete, change)
 
     return current[n]
-
-def minimal_distance(search_in, search_for, start=0, threshold=0):
-    best_distance = 1000000000
-    best_offset = -1
-    best_len = -1
-    window = 10
-    rough_acceptable_distance = int(1.5 * window)
-    acceptable_distance = int(len(search_for) * threshold)
-    stop = len(search_in)-len(search_for)
-    for rough_offset in range(start, stop, window):
-        rough_distance = levenshtein(search_in[rough_offset:rough_offset+len(search_for)], search_for)
-        if rough_distance < rough_acceptable_distance:
-            for offset in range(rough_offset-window, rough_offset+window, 1):
-                distance = levenshtein(search_in[offset:offset+len(search_for)], search_for)
-                if distance < best_distance:
-                    best_distance = distance
-                    best_offset = offset
-                    best_len = len(search_for)
-            if best_distance <= acceptable_distance:
-                return best_distance, best_offset, best_len
-    return -1, 0, 0
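The removed minimal_distance helper (a linear windowed Levenshtein scan) is superseded by the n-gram indexed LevenshteinSearch added above, which only runs Levenshtein comparisons inside the most promising trigram-match windows. A minimal usage sketch, assuming the module is named text as in the alignment script's import; the sample strings are invented:

    from text import LevenshteinSearch

    search = LevenshteinSearch('the quick brown fox jumps over the lazy dog')
    distance, offset, length = search.find_best('quick brwn fox')
    if offset >= 0:
        # best fuzzy-match window inside the indexed text
        print(search.text[offset:offset + length])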

View file

@@ -94,7 +94,7 @@ def vad_collector(sample_rate, frame_duration_ms,
     triggered = False
     voiced_frames = []
-    for frame in frames:
+    for frame_index, frame in enumerate(frames):
         is_speech = vad.is_speech(frame.bytes, sample_rate)
 
         if not triggered:
@@ -122,13 +122,18 @@ def vad_collector(sample_rate, frame_duration_ms,
                 # audio we've collected.
                 if num_unvoiced > threshold * ring_buffer.maxlen:
                     triggered = False
-                    yield b''.join([f.bytes for f in voiced_frames])
+                    yield b''.join([f.bytes for f in voiced_frames]), \
+                          frame_duration_ms * (frame_index - len(voiced_frames)), \
+                          frame_duration_ms * frame_index
                     ring_buffer.clear()
                     voiced_frames = []
     if triggered:
         pass
     # If we have any leftover voiced audio when we run out of input,
     # yield it.
     if voiced_frames:
-        yield b''.join([f.bytes for f in voiced_frames])
+        yield b''.join([f.bytes for f in voiced_frames]), \
+              frame_duration_ms * (frame_index - len(voiced_frames)), \
+              frame_duration_ms * (frame_index + 1)
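Since vad_collector now yields (buffer, time_start, time_end) tuples rather than bare byte buffers, consumers must unpack three values. A sketch of the consuming side, mirroring the loop in the alignment script; 'speech.wav' is a placeholder path and aggressiveness 3 matches the script's default:

    import numpy as np
    import wavTranscriber

    segments, sample_rate, audio_length = wavTranscriber.vad_segment_generator('speech.wav', 3)
    for segment_buffer, time_start, time_end in segments:
        # start/end are in milliseconds, derived from frame_duration_ms
        audio = np.frombuffer(segment_buffer, dtype=np.int16)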

View file

@@ -1,3 +1,4 @@
+nltk
 six
 deepspeech
 webrtcvad
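Picking up the new dependency should only require reinstalling the requirements; nltk appears to be used solely for its ngrams helper here, so no corpus download step should be needed:

    pip install -r requirements.txt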