Mirror of https://github.com/mozilla/DSAlign.git
Parent: c1c222c2db
Commit: 6e29fa594b
@@ -1,8 +1,10 @@
import sys
import os
import text
import json
import logging
import argparse
import os.path as path
import numpy as np
import wavTranscriber
@@ -10,9 +12,11 @@ import wavTranscriber
def main(args):
    parser = argparse.ArgumentParser(description='Force align speech data with a transcript.')
    parser.add_argument('audio', type=str,
                        help='Path to speech audio (WAV format)')
                        help='Source path of speech audio (WAV format)')
    parser.add_argument('transcript', type=str,
                        help='Path to original transcript')
                        help='Source path of original transcript (plain text)')
    parser.add_argument('result', type=str,
                        help='Target path of alignment result file (JSON)')
    parser.add_argument('--aggressive', type=int, choices=range(4), required=False,
                        help='Determines how aggressive filtering out non-speech is. (Interger between 0-3)')
    parser.add_argument('--model', required=False,
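Going by the argument parser above, a typical invocation would look roughly like this (the file names are placeholders, and align.py stands for the script being diffed here):

python align.py speech.wav transcript.txt aligned.json --aggressive 3 --model models/en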
@@ -23,43 +27,70 @@ def main(args):

    # Debug helpers
    logging.basicConfig(stream=sys.stderr, level=args.loglevel if args.loglevel else 20)
    logging.debug("Start")

    # Loading model
    fragments = []
    fragments_cache_path = args.result + '.cache'
    model_dir = os.path.expanduser(args.model if args.model else 'models/en')
    output_graph, alphabet, lm, trie = wavTranscriber.resolve_models(model_dir)
    model, _, _ = wavTranscriber.load_model(output_graph, alphabet, lm, trie)
    alphabet = text.Alphabet(alphabet)
    logging.debug("Looking for model files in %s..." % model_dir)
    output_graph_path, alphabet_path, lm_path, trie_path = wavTranscriber.resolve_models(model_dir)
    logging.debug("Loading alphabet from %s..." % alphabet_path)
    alphabet = text.Alphabet(alphabet_path)

    inference_time = 0.0
    if path.exists(fragments_cache_path):
        logging.debug("Loading cached segment transcripts from %s..." % fragments_cache_path)
        with open(fragments_cache_path, 'r') as result_file:
            fragments = json.loads(result_file.read())
    else:
        logging.debug("Loading model from %s..." % model_dir)
        model, _, _ = wavTranscriber.load_model(output_graph_path, alphabet_path, lm_path, trie_path)

        # Run VAD on the input file
        wave_file = args.audio
        aggressiveness = int(args.aggressive) if args.aggressive else 3
        segments, sample_rate, audio_length = wavTranscriber.vad_segment_generator(wave_file, aggressiveness)
        inference_time = 0.0
        offset = 0

        # Run VAD on the input file
        logging.debug("Transcribing VAD segments...")
        wave_file = args.audio
        aggressiveness = int(args.aggressive) if args.aggressive else 3
        segments, sample_rate, audio_length = wavTranscriber.vad_segment_generator(wave_file, aggressiveness)

        for i, segment in enumerate(segments):
            # Run DeepSpeech on the chunk that just completed VAD
            segment_buffer, time_start, time_end = segment
            logging.debug("Transcribing segment %002d (from %f to %f)..." % (i, time_start / 1000.0, time_end / 1000.0))

            audio = np.frombuffer(segment_buffer, dtype=np.int16)
            segment_transcript, segment_inference_time = wavTranscriber.stt(model, audio, sample_rate)
            if segment_transcript is None:
                logging.debug("Segment %002d empty" % i)
                continue
            inference_time += segment_inference_time
            fragments.append({
                'time_start': time_start,
                'time_end': time_end,
                'transcript': segment_transcript,
                'offset': offset
            })
            offset += len(segment_transcript)

        logging.debug("Writing segment transcripts to cache file %s..." % fragments_cache_path)
        with open(fragments_cache_path, 'w') as result_file:
            result_file.write(json.dumps(fragments))

    logging.debug("Loading original transcript from %s..." % args.transcript)
    with open(args.transcript, 'r') as transcript_file:
        original_transcript = transcript_file.read()
    original_transcript = ' '.join(original_transcript.lower().split())
    original_transcript = alphabet.filter(original_transcript)

    position = 0

    for i, segment in enumerate(segments):
        # Run DeepSpeech on the chunk that just completed VAD
        logging.debug("Transcribing segment %002d..." % i)
        audio = np.frombuffer(segment, dtype=np.int16)
        segment_transcript, segment_inference_time = wavTranscriber.stt(model, audio, sample_rate)
        inference_time += segment_inference_time

        logging.debug("Looking for segment transcript in original transcript...")
        distance, found_offset, found_len = \
            text.minimal_distance(original_transcript,
                                  segment_transcript,
                                  start=position,
                                  threshold=0.1)
        logging.info("Segment transcript: %s" % segment_transcript)
        logging.info("Segment found: %s" % original_transcript[found_offset:found_offset+found_len])
        logging.info("--")
    ls = text.LevenshteinSearch(original_transcript)
    start = 0
    for fragment in fragments:
        logging.debug('STT Transcribed: %s' % fragment['transcript'])
        match_distance, match_offset, match_len = ls.find_best(fragment['transcript'])
        if match_offset >= 0:
            fragment['original'] = original_transcript[match_offset:match_offset+match_len]
            logging.debug(' Original: %s' % fragment['original'])
            start = match_offset+match_len


if __name__ == '__main__':
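For reference, the cache written to args.result + '.cache' above is plain JSON produced by json.dumps(fragments). Judging from the keys appended in the transcription loop, one cached entry would look roughly like this sketch (values invented for illustration):

# Illustrative shape of one cached fragment; times come from the VAD segments and are in milliseconds
fragments = [{
    'time_start': 1200,    # segment start (ms)
    'time_end': 4360,      # segment end (ms)
    'transcript': 'what deepspeech heard in this segment',
    'offset': 0            # running character offset of this transcript in the concatenated STT output
}]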
@@ -2,6 +2,7 @@ from __future__ import absolute_import, division, print_function

import codecs
import logging
from nltk import ngrams

from six.moves import range
@@ -55,6 +56,57 @@ class Alphabet(object):
        return self._config_file


class TextCleaner:
    def __init__(self, original_text, alphabet, to_lower=True, normalize_space=True):
        self.original_text = original_text
        clean_text = original_text
        if to_lower:
            clean_text = clean_text.lower()
        if normalize_space:
            clean_text = ' '.join(clean_text.split())
        self.clean_text = alphabet.filter(clean_text)

    def get_original_offset(self, clean_offset):
        return clean_offset


class LevenshteinSearch:
    def __init__(self, text):
        self.text = text
        self.ngrams = {}
        for i, ngram in enumerate(ngrams(' ' + text + ' ', 3)):
            if ngram in self.ngrams:
                ngram_bucket = self.ngrams[ngram]
            else:
                ngram_bucket = self.ngrams[ngram] = []
            ngram_bucket.append(i)

    def find_best(self, look_for, start=0, stop=-1, threshold=0):
        stop = len(self.text) if stop < 0 else stop
        window_size = len(look_for)
        windows = {}
        for i, ngram in enumerate(ngrams(' ' + look_for + ' ', 3)):
            if ngram in self.ngrams:
                ngram_bucket = self.ngrams[ngram]
                for occurrence in ngram_bucket:
                    if occurrence < start or occurrence > stop:
                        continue
                    window = occurrence // window_size
                    windows[window] = (windows[window] + 1) if window in windows else 1
        candidate_windows = sorted(windows.keys(), key=lambda w: windows[w], reverse=True)
        found_best = False
        best_distance = -1
        best_offset = -1
        best_len = -1
        for window in candidate_windows[0:4]:
            for offset in range(int((window-0.5)*window_size), int((window+0.5)*window_size)):
                distance = levenshtein(self.text[offset:offset + len(look_for)], look_for)
                if not found_best or distance < best_distance:
                    found_best = True
                    best_distance = distance
                    best_offset = offset
                    best_len = len(look_for)
        return best_distance, best_offset, best_len


# The following code is from: http://hetland.org/coding/python/levenshtein.py
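A minimal usage sketch for the new LevenshteinSearch class, assuming the module is importable as text and nltk is installed (the strings are toy data; align.py instead indexes the filtered original transcript and queries it with each STT fragment):

import text

# Build the trigram index once over the reference text, then query it.
search = text.LevenshteinSearch('the quick brown fox jumps over the lazy dog')
distance, offset, length = search.find_best('quick brown focks')
if offset >= 0:
    print(search.text[offset:offset + length])  # best-matching slice of the reference text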
@@ -88,25 +140,3 @@ def levenshtein(a, b):
            current[j] = min(add, delete, change)

    return current[n]


def minimal_distance(search_in, search_for, start=0, threshold=0):
    best_distance = 1000000000
    best_offset = -1
    best_len = -1
    window = 10
    rough_acceptable_distance = int(1.5 * window)
    acceptable_distance = int(len(search_for) * threshold)
    stop = len(search_in)-len(search_for)
    for rough_offset in range(start, stop, window):
        rough_distance = levenshtein(search_in[rough_offset:rough_offset+len(search_for)], search_for)
        if rough_distance < rough_acceptable_distance:
            for offset in range(rough_offset-window, rough_offset+window, 1):
                distance = levenshtein(search_in[offset:offset+len(search_for)], search_for)
                if distance < best_distance:
                    best_distance = distance
                    best_offset = offset
                    best_len = len(search_for)
    if best_distance <= acceptable_distance:
        return best_distance, best_offset, best_len
    return -1, 0, 0
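As a quick sanity check of the levenshtein function that this hunk keeps, the textbook example:

# Three single-character edits: k->s, e->i, and appending g.
assert levenshtein('kitten', 'sitting') == 3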
@@ -94,7 +94,7 @@ def vad_collector(sample_rate, frame_duration_ms,
    triggered = False

    voiced_frames = []
    for frame in frames:
    for frame_index, frame in enumerate(frames):
        is_speech = vad.is_speech(frame.bytes, sample_rate)

        if not triggered:
@@ -122,13 +122,18 @@ def vad_collector(sample_rate, frame_duration_ms,
            # audio we've collected.
            if num_unvoiced > threshold * ring_buffer.maxlen:
                triggered = False
                yield b''.join([f.bytes for f in voiced_frames])
                yield b''.join([f.bytes for f in voiced_frames]), \
                    frame_duration_ms * (frame_index - len(voiced_frames)), \
                    frame_duration_ms * frame_index
                ring_buffer.clear()
                voiced_frames = []

    if triggered:
        pass
    # If we have any leftover voiced audio when we run out of input,
    # yield it.
    if voiced_frames:
        yield b''.join([f.bytes for f in voiced_frames])
        yield b''.join([f.bytes for f in voiced_frames]), \
            frame_duration_ms * (frame_index - len(voiced_frames)), \
            frame_duration_ms * (frame_index + 1)
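With this change each yielded item is a (bytes, start_ms, end_ms) tuple rather than a bare byte string, so a consumer now unpacks it roughly the way align.py does. A sketch, assuming segments comes from wavTranscriber.vad_segment_generator and numpy is installed:

import numpy as np

for segment_buffer, time_start, time_end in segments:
    # Raw 16-bit PCM bytes become samples; segment boundaries are reported in seconds.
    audio = np.frombuffer(segment_buffer, dtype=np.int16)
    print("segment from %.2f to %.2f seconds" % (time_start / 1000.0, time_end / 1000.0))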
@@ -1,3 +1,4 @@
nltk
six
deepspeech
webrtcvad