This commit is contained in:
Tilman Kamp 2019-07-04 18:00:47 +02:00
Parent c1c222c2db
Commit 6e29fa594b
4 changed files with 122 additions and 55 deletions

View file

@@ -1,8 +1,10 @@
import sys
import os
import text
import json
import logging
import argparse
import os.path as path
import numpy as np
import wavTranscriber
@@ -10,9 +12,11 @@ import wavTranscriber
def main(args):
parser = argparse.ArgumentParser(description='Force align speech data with a transcript.')
parser.add_argument('audio', type=str,
help='Path to speech audio (WAV format)')
help='Source path of speech audio (WAV format)')
parser.add_argument('transcript', type=str,
help='Path to original transcript')
help='Source path of original transcript (plain text)')
parser.add_argument('result', type=str,
help='Target path of alignment result file (JSON)')
parser.add_argument('--aggressive', type=int, choices=range(4), required=False,
help='Determines how aggressive filtering out non-speech is. (Integer between 0-3)')
parser.add_argument('--model', required=False,
@@ -23,43 +27,70 @@ def main(args):
# Debug helpers
logging.basicConfig(stream=sys.stderr, level=args.loglevel if args.loglevel else 20)
logging.debug("Start")
# Loading model
fragments = []
fragments_cache_path = args.result + '.cache'
model_dir = os.path.expanduser(args.model if args.model else 'models/en')
output_graph, alphabet, lm, trie = wavTranscriber.resolve_models(model_dir)
model, _, _ = wavTranscriber.load_model(output_graph, alphabet, lm, trie)
alphabet = text.Alphabet(alphabet)
logging.debug("Looking for model files in %s..." % model_dir)
output_graph_path, alphabet_path, lm_path, trie_path = wavTranscriber.resolve_models(model_dir)
logging.debug("Loading alphabet from %s..." % alphabet_path)
alphabet = text.Alphabet(alphabet_path)
inference_time = 0.0
if path.exists(fragments_cache_path):
logging.debug("Loading cached segment transcripts from %s..." % fragments_cache_path)
with open(fragments_cache_path, 'r') as result_file:
fragments = json.loads(result_file.read())
else:
logging.debug("Loading model from %s..." % model_dir)
model, _, _ = wavTranscriber.load_model(output_graph_path, alphabet_path, lm_path, trie_path)
# Run VAD on the input file
wave_file = args.audio
aggressiveness = int(args.aggressive) if args.aggressive else 3
segments, sample_rate, audio_length = wavTranscriber.vad_segment_generator(wave_file, aggressiveness)
inference_time = 0.0
offset = 0
# Run VAD on the input file
logging.debug("Transcribing VAD segments...")
wave_file = args.audio
aggressiveness = int(args.aggressive) if args.aggressive else 3
segments, sample_rate, audio_length = wavTranscriber.vad_segment_generator(wave_file, aggressiveness)
for i, segment in enumerate(segments):
# Run DeepSpeech on the chunk that just completed VAD
segment_buffer, time_start, time_end = segment
logging.debug("Transcribing segment %002d (from %f to %f)..." % (i, time_start / 1000.0, time_end / 1000.0))
audio = np.frombuffer(segment_buffer, dtype=np.int16)
segment_transcript, segment_inference_time = wavTranscriber.stt(model, audio, sample_rate)
if segment_transcript is None:
logging.debug("Segment %002d empty" % i)
continue
inference_time += segment_inference_time
fragments.append({
'time_start': time_start,
'time_end': time_end,
'transcript': segment_transcript,
'offset': offset
})
offset += len(segment_transcript)
logging.debug("Writing segment transcripts to cache file %s..." % fragments_cache_path)
with open(fragments_cache_path, 'w') as result_file:
result_file.write(json.dumps(fragments))
logging.debug("Loading original transcript from %s..." % args.transcript)
with open(args.transcript, 'r') as transcript_file:
original_transcript = transcript_file.read()
original_transcript = ' '.join(original_transcript.lower().split())
original_transcript = alphabet.filter(original_transcript)
position = 0
for i, segment in enumerate(segments):
# Run DeepSpeech on the chunk that just completed VAD
logging.debug("Transcribing segment %002d..." % i)
audio = np.frombuffer(segment, dtype=np.int16)
segment_transcript, segment_inference_time = wavTranscriber.stt(model, audio, sample_rate)
inference_time += segment_inference_time
logging.debug("Looking for segment transcript in original transcript...")
distance, found_offset, found_len = \
text.minimal_distance(original_transcript,
segment_transcript,
start=position,
threshold=0.1)
logging.info("Segment transcript: %s" % segment_transcript)
logging.info("Segment found: %s" % original_transcript[found_offset:found_offset+found_len])
logging.info("--")
ls = text.LevenshteinSearch(original_transcript)
start = 0
for fragment in fragments:
logging.debug('STT Transcribed: %s' % fragment['transcript'])
match_distance, match_offset, match_len = ls.find_best(fragment['transcript'])
if match_offset >= 0:
fragment['original'] = original_transcript[match_offset:match_offset+match_len]
logging.debug(' Original: %s' % fragment['original'])
start = match_offset+match_len
if __name__ == '__main__':
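The loop above caches its per-segment STT results next to the target file (args.result + '.cache'), so re-runs can skip inference. A minimal sketch for inspecting that cache, assuming a result path of result.json (hypothetical); the fragment keys mirror the dict built in the loop:

import json

# Each cached fragment carries the segment boundaries in milliseconds,
# the DeepSpeech transcript and its character offset in the joined stream.
with open('result.json.cache', 'r') as cache_file:  # hypothetical result path
    fragments = json.loads(cache_file.read())
for fragment in fragments:
    print('%d-%d ms (offset %d): %s' % (fragment['time_start'],
                                        fragment['time_end'],
                                        fragment['offset'],
                                        fragment['transcript']))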

View file

@@ -2,6 +2,7 @@ from __future__ import absolute_import, division, print_function
import codecs
import logging
from nltk import ngrams
from six.moves import range
@@ -55,6 +56,57 @@ class Alphabet(object):
return self._config_file
class TextCleaner:
def __init__(self, original_text, alphabet, to_lower=True, normalize_space=True):
self.original_text = original_text
clean_text = original_text
if to_lower:
clean_text = clean_text.lower()
if normalize_space:
clean_text = ' '.join(clean_text.split())
self.clean_text = alphabet.filter(clean_text)
def get_original_offset(self, clean_offset):
return clean_offset
class LevenshteinSearch:
def __init__(self, text):
self.text = text
self.ngrams = {}
for i, ngram in enumerate(ngrams(' ' + text + ' ', 3)):
if ngram in self.ngrams:
ngram_bucket = self.ngrams[ngram]
else:
ngram_bucket = self.ngrams[ngram] = []
ngram_bucket.append(i)
def find_best(self, look_for, start=0, stop=-1, threshold=0):
stop = len(self.text) if stop < 0 else stop
window_size = len(look_for)
windows = {}
for i, ngram in enumerate(ngrams(' ' + look_for + ' ', 3)):
if ngram in self.ngrams:
ngram_bucket = self.ngrams[ngram]
for occurrence in ngram_bucket:
if occurrence < start or occurrence > stop:
continue
window = occurrence // window_size
windows[window] = (windows[window] + 1) if window in windows else 1
candidate_windows = sorted(windows.keys(), key=lambda w: windows[w], reverse=True)
found_best = False
best_distance = -1
best_offset = -1
best_len = -1
for window in candidate_windows[0:4]:
for offset in range(int((window-0.5)*window_size), int((window+0.5)*window_size)):
distance = levenshtein(self.text[offset:offset + len(look_for)], look_for)
if not found_best or distance < best_distance:
found_best = True
best_distance = distance
best_offset = offset
best_len = len(look_for)
return best_distance, best_offset, best_len
# The following code is from: http://hetland.org/coding/python/levenshtein.py
@@ -88,25 +140,3 @@ def levenshtein(a, b):
current[j] = min(add, delete, change)
return current[n]
def minimal_distance(search_in, search_for, start=0, threshold=0):
best_distance = 1000000000
best_offset = -1
best_len = -1
window = 10
rough_acceptable_distance = int(1.5 * window)
acceptable_distance = int(len(search_for) * threshold)
stop = len(search_in)-len(search_for)
for rough_offset in range(start, stop, window):
rough_distance = levenshtein(search_in[rough_offset:rough_offset+len(search_for)], search_for)
if rough_distance < rough_acceptable_distance:
for offset in range(rough_offset-window, rough_offset+window, 1):
distance = levenshtein(search_in[offset:offset+len(search_for)], search_for)
if distance < best_distance:
best_distance = distance
best_offset = offset
best_len = len(search_for)
if best_distance <= acceptable_distance:
return best_distance, best_offset, best_len
return -1, 0, 0
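A minimal usage sketch of the new LevenshteinSearch class, which replaces the removed minimal_distance above; the sample strings are made up, and find_best returns a (distance, offset, length) triple with offset -1 when no candidate window matched, which is what align.py checks:

from text import LevenshteinSearch

original = 'the quick brown fox jumps over the lazy dog'
ls = LevenshteinSearch(original)
# Locate a noisy STT transcript inside the cleaned original transcript.
distance, offset, length = ls.find_best('the quik brown fox')
if offset >= 0:
    print(original[offset:offset + length])  # best matching substring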

View file

@@ -94,7 +94,7 @@ def vad_collector(sample_rate, frame_duration_ms,
triggered = False
voiced_frames = []
for frame in frames:
for frame_index, frame in enumerate(frames):
is_speech = vad.is_speech(frame.bytes, sample_rate)
if not triggered:
@@ -122,13 +122,18 @@
# audio we've collected.
if num_unvoiced > threshold * ring_buffer.maxlen:
triggered = False
yield b''.join([f.bytes for f in voiced_frames])
yield b''.join([f.bytes for f in voiced_frames]), \
frame_duration_ms * (frame_index - len(voiced_frames)), \
frame_duration_ms * frame_index
ring_buffer.clear()
voiced_frames = []
if triggered:
pass
# If we have any leftover voiced audio when we run out of input,
# yield it.
if voiced_frames:
yield b''.join([f.bytes for f in voiced_frames])
yield b''.join([f.bytes for f in voiced_frames]), \
frame_duration_ms * (frame_index - len(voiced_frames)), \
frame_duration_ms * (frame_index + 1)
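With this change the collector yields (buffer, start_ms, end_ms) tuples instead of bare byte strings. A minimal consumer sketch, mirroring the call in align.py; the WAV path and aggressiveness value are placeholders:

import wavTranscriber

segments, sample_rate, audio_length = wavTranscriber.vad_segment_generator('speech.wav', 3)
for segment_buffer, time_start, time_end in segments:
    # Segment boundaries come back in milliseconds (multiples of frame_duration_ms).
    print('voiced segment %.2f s to %.2f s (%d bytes)'
          % (time_start / 1000.0, time_end / 1000.0, len(segment_buffer)))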

View file

@@ -1,3 +1,4 @@
nltk
six
deepspeech
webrtcvad