Mirror of https://github.com/mozilla/DSAlign.git

Commit 6e29fa594b (parent c1c222c2db)
@@ -1,8 +1,10 @@
 import sys
 import os
 import text
+import json
 import logging
 import argparse
+import os.path as path
 import numpy as np
 import wavTranscriber

@@ -10,9 +12,11 @@ import wavTranscriber
 def main(args):
     parser = argparse.ArgumentParser(description='Force align speech data with a transcript.')
     parser.add_argument('audio', type=str,
-                        help='Path to speech audio (WAV format)')
+                        help='Source path of speech audio (WAV format)')
     parser.add_argument('transcript', type=str,
-                        help='Path to original transcript')
+                        help='Source path of original transcript (plain text)')
+    parser.add_argument('result', type=str,
+                        help='Target path of alignment result file (JSON)')
     parser.add_argument('--aggressive', type=int, choices=range(4), required=False,
                         help='Determines how aggressive filtering out non-speech is. (Integer between 0-3)')
     parser.add_argument('--model', required=False,
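Note: a minimal sketch of how the updated argument set would be parsed; the file names below are made up for illustration and are not part of the commit.

import argparse

# Hypothetical stand-in for the parser built in main(), showing the three
# positional arguments after this change: audio, transcript and result.
parser = argparse.ArgumentParser(description='Force align speech data with a transcript.')
parser.add_argument('audio', type=str, help='Source path of speech audio (WAV format)')
parser.add_argument('transcript', type=str, help='Source path of original transcript (plain text)')
parser.add_argument('result', type=str, help='Target path of alignment result file (JSON)')

args = parser.parse_args(['speech.wav', 'speech.txt', 'speech.aligned'])
print(args.result)             # speech.aligned
print(args.result + '.cache')  # speech.aligned.cache, the segment transcript cache used below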
@@ -23,43 +27,70 @@ def main(args):

     # Debug helpers
     logging.basicConfig(stream=sys.stderr, level=args.loglevel if args.loglevel else 20)
+    logging.debug("Start")

-    # Loading model
+    fragments = []
+    fragments_cache_path = args.result + '.cache'
     model_dir = os.path.expanduser(args.model if args.model else 'models/en')
-    output_graph, alphabet, lm, trie = wavTranscriber.resolve_models(model_dir)
-    model, _, _ = wavTranscriber.load_model(output_graph, alphabet, lm, trie)
-    alphabet = text.Alphabet(alphabet)
+    logging.debug("Looking for model files in %s..." % model_dir)
+    output_graph_path, alphabet_path, lm_path, trie_path = wavTranscriber.resolve_models(model_dir)
+    logging.debug("Loading alphabet from %s..." % alphabet_path)
+    alphabet = text.Alphabet(alphabet_path)

-    inference_time = 0.0
+    if path.exists(fragments_cache_path):
+        logging.debug("Loading cached segment transcripts from %s..." % fragments_cache_path)
+        with open(fragments_cache_path, 'r') as result_file:
+            fragments = json.loads(result_file.read())
+    else:
+        logging.debug("Loading model from %s..." % model_dir)
+        model, _, _ = wavTranscriber.load_model(output_graph_path, alphabet_path, lm_path, trie_path)

-    # Run VAD on the input file
-    wave_file = args.audio
-    aggressiveness = int(args.aggressive) if args.aggressive else 3
-    segments, sample_rate, audio_length = wavTranscriber.vad_segment_generator(wave_file, aggressiveness)
+        inference_time = 0.0
+        offset = 0
+
+        # Run VAD on the input file
+        logging.debug("Transcribing VAD segments...")
+        wave_file = args.audio
+        aggressiveness = int(args.aggressive) if args.aggressive else 3
+        segments, sample_rate, audio_length = wavTranscriber.vad_segment_generator(wave_file, aggressiveness)
+
+        for i, segment in enumerate(segments):
+            # Run DeepSpeech on the chunk that just completed VAD
+            segment_buffer, time_start, time_end = segment
+            logging.debug("Transcribing segment %002d (from %f to %f)..." % (i, time_start / 1000.0, time_end / 1000.0))
+
+            audio = np.frombuffer(segment_buffer, dtype=np.int16)
+            segment_transcript, segment_inference_time = wavTranscriber.stt(model, audio, sample_rate)
+            if segment_transcript is None:
+                logging.debug("Segment %002d empty" % i)
+                continue
+            inference_time += segment_inference_time
+            fragments.append({
+                'time_start': time_start,
+                'time_end': time_end,
+                'transcript': segment_transcript,
+                'offset': offset
+            })
+            offset += len(segment_transcript)
+
+        logging.debug("Writing segment transcripts to cache file %s..." % fragments_cache_path)
+        with open(fragments_cache_path, 'w') as result_file:
+            result_file.write(json.dumps(fragments))

+    logging.debug("Loading original transcript from %s..." % args.transcript)
     with open(args.transcript, 'r') as transcript_file:
         original_transcript = transcript_file.read()
     original_transcript = ' '.join(original_transcript.lower().split())
     original_transcript = alphabet.filter(original_transcript)
-    position = 0
-    for i, segment in enumerate(segments):
-        # Run DeepSpeech on the chunk that just completed VAD
-        logging.debug("Transcribing segment %002d..." % i)
-        audio = np.frombuffer(segment, dtype=np.int16)
-        segment_transcript, segment_inference_time = wavTranscriber.stt(model, audio, sample_rate)
-        inference_time += segment_inference_time
-
-        logging.debug("Looking for segment transcript in original transcript...")
-        distance, found_offset, found_len = \
-            text.minimal_distance(original_transcript,
-                                  segment_transcript,
-                                  start=position,
-                                  threshold=0.1)
-        logging.info("Segment transcript: %s" % segment_transcript)
-        logging.info("Segment found: %s" % original_transcript[found_offset:found_offset+found_len])
-        logging.info("--")
+    ls = text.LevenshteinSearch(original_transcript)
+    start = 0
+    for fragment in fragments:
+        logging.debug('STT Transcribed: %s' % fragment['transcript'])
+        match_distance, match_offset, match_len = ls.find_best(fragment['transcript'])
+        if match_offset >= 0:
+            fragment['original'] = original_transcript[match_offset:match_offset+match_len]
+            logging.debug(' Original: %s' % fragment['original'])
+            start = match_offset+match_len


 if __name__ == '__main__':
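Note: with the hunk above, segment transcripts are cached next to the result file so a repeated run can skip DeepSpeech inference. A rough sketch of reading that cache back; the path and values are illustrative, assuming a result argument of 'speech.aligned'.

import json

# Hypothetical inspection of the cache written by the loop above: a JSON list
# of fragment dicts, one per VAD segment that produced a transcript.
with open('speech.aligned.cache', 'r') as cache_file:
    fragments = json.loads(cache_file.read())

for fragment in fragments:
    # Segment boundaries in milliseconds, the DeepSpeech transcript, and the
    # fragment's character offset in the concatenated STT output.
    print(fragment['time_start'], fragment['time_end'],
          fragment['offset'], fragment['transcript'])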
@@ -2,6 +2,7 @@ from __future__ import absolute_import, division, print_function

 import codecs
 import logging
+from nltk import ngrams

 from six.moves import range

@@ -55,6 +56,57 @@ class Alphabet(object):
         return self._config_file


+class TextCleaner:
+    def __init__(self, original_text, alphabet, to_lower=True, normalize_space=True):
+        self.original_text = original_text
+        clean_text = original_text
+        if to_lower:
+            clean_text = clean_text.lower()
+        if normalize_space:
+            clean_text = ' '.join(clean_text.split())
+        self.clean_text = alphabet.filter(clean_text)
+
+    def get_original_offset(self, clean_offset):
+        return clean_offset
+
+
+class LevenshteinSearch:
+    def __init__(self, text):
+        self.text = text
+        self.ngrams = {}
+        for i, ngram in enumerate(ngrams(' ' + text + ' ', 3)):
+            if ngram in self.ngrams:
+                ngram_bucket = self.ngrams[ngram]
+            else:
+                ngram_bucket = self.ngrams[ngram] = []
+            ngram_bucket.append(i)
+
+    def find_best(self, look_for, start=0, stop=-1, threshold=0):
+        stop = len(self.text) if stop < 0 else stop
+        window_size = len(look_for)
+        windows = {}
+        for i, ngram in enumerate(ngrams(' ' + look_for + ' ', 3)):
+            if ngram in self.ngrams:
+                ngram_bucket = self.ngrams[ngram]
+                for occurrence in ngram_bucket:
+                    if occurrence < start or occurrence > stop:
+                        continue
+                    window = occurrence // window_size
+                    windows[window] = (windows[window] + 1) if window in windows else 1
+        candidate_windows = sorted(windows.keys(), key=lambda w: windows[w], reverse=True)
+        found_best = False
+        best_distance = -1
+        best_offset = -1
+        best_len = -1
+        for window in candidate_windows[0:4]:
+            for offset in range(int((window-0.5)*window_size), int((window+0.5)*window_size)):
+                distance = levenshtein(self.text[offset:offset + len(look_for)], look_for)
+                if not found_best or distance < best_distance:
+                    found_best = True
+                    best_distance = distance
+                    best_offset = offset
+                    best_len = len(look_for)
+        return best_distance, best_offset, best_len
+
+
 # The following code is from: http://hetland.org/coding/python/levenshtein.py
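Note: the new LevenshteinSearch indexes the original transcript by character trigrams, so find_best only runs levenshtein() over a few high-scoring candidate windows instead of sliding over the whole text the way the minimal_distance() helper removed further down did. A minimal usage sketch, assuming the module is importable as text (as in the aligner above) and using made-up strings:

from text import LevenshteinSearch

# Index the cleaned original transcript once, then look up each STT fragment.
original = 'the quick brown fox jumps over the lazy dog'
search = LevenshteinSearch(original)

# find_best returns (distance, offset, length); offset stays -1 when no
# candidate trigram window was found at all.
distance, offset, length = search.find_best('quick brown vox')
if offset >= 0:
    print(original[offset:offset + length])  # best-matching slice of the original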
@@ -88,25 +140,3 @@ def levenshtein(a, b):
             current[j] = min(add, delete, change)

     return current[n]
-
-
-def minimal_distance(search_in, search_for, start=0, threshold=0):
-    best_distance = 1000000000
-    best_offset = -1
-    best_len = -1
-    window = 10
-    rough_acceptable_distance = int(1.5 * window)
-    acceptable_distance = int(len(search_for) * threshold)
-    stop = len(search_in)-len(search_for)
-    for rough_offset in range(start, stop, window):
-        rough_distance = levenshtein(search_in[rough_offset:rough_offset+len(search_for)], search_for)
-        if rough_distance < rough_acceptable_distance:
-            for offset in range(rough_offset-window, rough_offset+window, 1):
-                distance = levenshtein(search_in[offset:offset+len(search_for)], search_for)
-                if distance < best_distance:
-                    best_distance = distance
-                    best_offset = offset
-                    best_len = len(search_for)
-                if best_distance <= acceptable_distance:
-                    return best_distance, best_offset, best_len
-    return -1, 0, 0
@@ -94,7 +94,7 @@ def vad_collector(sample_rate, frame_duration_ms,
     triggered = False

     voiced_frames = []
-    for frame in frames:
+    for frame_index, frame in enumerate(frames):
         is_speech = vad.is_speech(frame.bytes, sample_rate)

         if not triggered:
@@ -122,13 +122,18 @@ def vad_collector(sample_rate, frame_duration_ms,
             # audio we've collected.
             if num_unvoiced > threshold * ring_buffer.maxlen:
                 triggered = False
-                yield b''.join([f.bytes for f in voiced_frames])
+                yield b''.join([f.bytes for f in voiced_frames]), \
+                      frame_duration_ms * (frame_index - len(voiced_frames)), \
+                      frame_duration_ms * frame_index
                 ring_buffer.clear()
                 voiced_frames = []

     if triggered:
         pass
     # If we have any leftover voiced audio when we run out of input,
     # yield it.
     if voiced_frames:
-        yield b''.join([f.bytes for f in voiced_frames])
+        yield b''.join([f.bytes for f in voiced_frames]), \
+              frame_duration_ms * (frame_index - len(voiced_frames)), \
+              frame_duration_ms * (frame_index + 1)
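Note: vad_collector now yields (segment bytes, start time, end time) in milliseconds instead of bare bytes, which matches the three-way unpacking in the aligner's segment loop above. A hedged consumer sketch, assuming vad_segment_generator passes these tuples through unchanged; the WAV path is made up:

import numpy as np
import wavTranscriber

# Each generated item is (segment_buffer, time_start_ms, time_end_ms).
segments, sample_rate, audio_length = wavTranscriber.vad_segment_generator('speech.wav', 3)
for i, (segment_buffer, time_start, time_end) in enumerate(segments):
    audio = np.frombuffer(segment_buffer, dtype=np.int16)
    print("segment %d: %.2fs to %.2fs, %d samples"
          % (i, time_start / 1000.0, time_end / 1000.0, len(audio)))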
@@ -1,3 +1,4 @@
+nltk
 six
 deepspeech
 webrtcvad