Multi-process transcription, .tlog files for transcripts, some light refactoring and reduced logging

Tilman Kamp 2019-09-02 12:26:10 +02:00
Parent 3031f35561
Commit 03f5b39d37
3 changed files with 76 additions and 73 deletions
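The core of the change: instead of transcribing VAD segments sequentially in the main process, the aligner now fans work out to a multiprocessing.Pool whose workers each load their own DeepSpeech model, and the transcripts are persisted as a JSON transcription log (.tlog) next to the audio. A minimal sketch of that pattern follows; the dummy model and path are illustrative stand-ins, not from the commit:

```python
import multiprocessing

model = None  # per-process state, filled in by the pool initializer


def init_worker(model_path):
    # Runs once in every worker process; the heavyweight model is
    # loaded here rather than pickled across the process boundary.
    global model
    model = 'model-loaded-from:{}'.format(model_path)  # stand-in load


def transcribe(start, end, audio):
    # Workers return the timing info alongside the transcript so each
    # result is self-describing and independent of arrival order.
    return start, end, '{} samples via {}'.format(len(audio), model)


if __name__ == '__main__':
    pool = multiprocessing.Pool(processes=None,  # None = one worker per CPU
                                initializer=init_worker,
                                initargs=('models/en/output_graph.pb',))
    segments = [(0, 1000, [0] * 160), (1000, 2500, [0] * 240)]
    # starmap unpacks each tuple into transcribe(start, end, audio)
    # and returns the results in input order.
    print(pool.starmap(transcribe, segments))
    pool.close()
    pool.join()
```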

README.md

@@ -42,9 +42,9 @@ $ sudo apt-get install build-essential libboost-all-dev cmake zlib1g-dev libbz2-
 ```
 With all requirements fulfilled, there is a script for building and installing KenLM
-in the right location:
+and the required DeepSpeech tools in the right location:
 ```bash
-$ bin/buildkenlm.sh
+$ bin/lm-dependencies.sh
 ```
 If all went well, the alignment tool will find and use it to automatically create individual

align.py

@@ -7,11 +7,30 @@ import subprocess
 import os.path as path
 import numpy as np
 import wavTranscriber
+import multiprocessing
 from collections import Counter
 from search import FuzzySearch
 from text import Alphabet, TextCleaner, levenshtein, similarity
 from utils import enweight
+
+model = None
+sample_rate = 0
+worker_index = 0
+
+
+def init_stt(output_graph_path, alphabet_path, lm_path, trie_path, rate):
+    global model, sample_rate
+    sample_rate = rate
+    print('Process {}: Loaded models'.format(os.getpid()))
+    model = wavTranscriber.load_model(output_graph_path, alphabet_path, lm_path, trie_path)
+
+
+def stt(time_start, time_end, audio):
+    print('Process {}: Transcribing...'.format(os.getpid()))
+    transcript = wavTranscriber.stt(model, audio, sample_rate)
+    print('Process {}: {}'.format(os.getpid(), transcript))
+    return time_start, time_end, ' '.join(transcript.split())
+
+
 def main(args):
     parser = argparse.ArgumentParser(description='Force align speech data with a transcript.')
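The new worker functions keep the model in module-level globals because a loaded DeepSpeech model handle cannot be sent through the pool's pickle-based task queue; init_stt runs once per child process and stt then reads the process-local copy (worker_index is declared but not yet used in this commit). A hedged single-process smoke test, assuming the module is importable as align and the usual models/en files exist:

```python
# Hypothetical smoke test: drive the worker functions without a pool.
import numpy as np
import align  # assumes align.py is importable from the current path

align.init_stt('models/en/output_graph.pb',  # assumed model file paths
               'models/en/alphabet.txt',
               'models/en/lm.binary',
               'models/en/trie',
               16000)
start, end, transcript = align.stt(0, 1000, np.zeros(16000, dtype=np.int16))
print(start, end, transcript)
```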
@@ -86,7 +105,7 @@ def main(args):
     output_group = parser.add_argument_group(title='Output options')
     output_group.add_argument('--output-stt', action="store_true",
-                              help='Writes STT transcripts to result file')
+                              help='Writes STT transcripts to result file as attribute "transcript"')
     output_group.add_argument('--output-aligned', action="store_true",
                               help='Writes clean aligned original transcripts to result file')
     output_group.add_argument('--output-aligned-raw', action="store_true",
@@ -127,8 +146,6 @@
     logging.basicConfig(stream=sys.stderr, level=args.loglevel if args.loglevel else 20)
     logging.debug("Start")
-    fragments = []
-    fragments_cache_path = args.result + '.cache'
     model_dir = os.path.expanduser(args.stt_model_dir if args.stt_model_dir else 'models/en')
     logging.debug("Looking for model files in %s..." % model_dir)
     output_graph_path, alphabet_path, lang_lm_path, lang_trie_path = wavTranscriber.resolve_models(model_dir)
@@ -147,10 +164,11 @@
     with open(clean_text_path, 'w') as clean_text_file:
         clean_text_file.write(tc.clean_text)
-    if path.exists(fragments_cache_path):
-        logging.debug("Loading cached segment transcripts from %s..." % fragments_cache_path)
-        with open(fragments_cache_path, 'r') as result_file:
-            fragments = json.loads(result_file.read())
+    transcription_log = os.path.splitext(args.audio)[0] + '.tlog'
+    if path.exists(transcription_log):
+        logging.debug("Loading transcription log from %s..." % transcription_log)
+        with open(transcription_log, 'r') as transcriptions_file:
+            fragments = json.loads(transcriptions_file.read())
     else:
         kenlm_path = 'dependencies/kenlm/build/bin'
         if not path.exists(kenlm_path):
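The transcription log is derived from the audio path, so /data/speech.wav gets /data/speech.tlog, and a re-run that finds the file skips STT entirely. Its content is the plain JSON list of fragments built further down in this diff; a small roundtrip sketch (path and values are made up):

```python
import json
import os.path

audio_path = '/data/speech.wav'                        # illustrative path
tlog_path = os.path.splitext(audio_path)[0] + '.tlog'  # -> /data/speech.tlog

fragments = [{'start': 0, 'end': 1520, 'transcript': 'hello world'}]
with open(tlog_path, 'w') as f:
    f.write(json.dumps(fragments))            # written after transcription
with open(tlog_path) as f:
    assert json.loads(f.read()) == fragments  # loaded on the next run
```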
@@ -194,7 +212,6 @@
         logging.debug('Loading acoustic model from "%s", alphabet from "%s" and language model from "%s"...' %
                       (output_graph_path, alphabet_path, lm_path))
-        model, _, _ = wavTranscriber.load_model(output_graph_path, alphabet_path, lm_path, trie_path)
         inference_time = 0.0
         offset = 0
@@ -203,36 +220,43 @@
         logging.debug("Transcribing VAD segments...")
         wave_file = args.audio
         aggressiveness = int(args.audio_vad_aggressiveness) if args.audio_vad_aggressiveness else 3
-        segments, sample_rate, audio_length = wavTranscriber.vad_segment_generator(wave_file, aggressiveness)
-        for i, segment in enumerate(segments):
-            # Run DeepSpeech on the chunk that just completed VAD
-            segment_buffer, time_start, time_end = segment
-            time_length = time_end - time_start
-            if args.stt_min_duration and time_length < args.stt_min_duration:
-                logging.info('Fragment {}: Audio too short for STT'.format(i))
-                continue
-            if args.stt_max_duration and time_length > args.stt_max_duration:
-                logging.info('Fragment {}: Audio too long for STT'.format(i))
-                continue
-            logging.debug("Transcribing segment %002d (from %f to %f)..." % (i, time_start / 1000.0, time_end / 1000.0))
-            audio = np.frombuffer(segment_buffer, dtype=np.int16)
-            segment_transcript, segment_inference_time = wavTranscriber.stt(model, audio, sample_rate)
-            segment_transcript = ' '.join(segment_transcript.split())
-            inference_time += segment_inference_time
+        segments, rate, audio_length = wavTranscriber.vad_segment_generator(wave_file, aggressiveness)
+        pool = multiprocessing.Pool(initializer=init_stt,
+                                    initargs=(output_graph_path, alphabet_path, lm_path, trie_path, rate),
+                                    processes=None)
+
+        def pre_filter():
+            for i, segment in enumerate(segments):
+                segment_buffer, time_start, time_end = segment
+                time_length = time_end - time_start
+                if args.stt_min_duration and time_length < args.stt_min_duration:
+                    logging.info('Fragment {}: Audio too short for STT'.format(i))
+                    continue
+                if args.stt_max_duration and time_length > args.stt_max_duration:
+                    logging.info('Fragment {}: Audio too long for STT'.format(i))
+                    continue
+                #logging.debug("Transcribing segment %002d (from %f to %f)..." % (i, time_start / 1000.0, time_end / 1000.0))
+                yield (time_start, time_end, np.frombuffer(segment_buffer, dtype=np.int16))
+
+        samples = list(pre_filter())[10:60]
+        transcripts = pool.starmap(stt, samples)
+        fragments = []
+        for time_start, time_end, segment_transcript in transcripts:
             if segment_transcript is None:
                 logging.debug("Segment %002d empty" % i)
                 continue
             fragments.append({
-                'time-start': time_start,
-                'time-length': time_length,
+                'start': time_start,
+                'end': time_end,
                 'transcript': segment_transcript
             })
             offset += len(segment_transcript)
-        logging.debug("Writing segment transcripts to cache file %s..." % fragments_cache_path)
-        with open(fragments_cache_path, 'w') as result_file:
-            result_file.write(json.dumps(fragments))
+        logging.debug("Writing transcription log to file %s..." % transcription_log)
+        with open(transcription_log, 'w') as transcriptions_file:
+            transcriptions_file.write(json.dumps(fragments))

     search = FuzzySearch(tc.clean_text,
                          max_candidates=args.align_max_candidates,
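Two details of the block above are easy to miss: the [10:60] slice caps the run at segments 10 to 59, and pool.starmap blocks until every task is done while preserving input order, which is what keeps the fragments time-sorted without any extra bookkeeping. The pool itself is never shut down here; the conventional close/join pattern looks like this (toy function, not from the commit):

```python
from multiprocessing import Pool


def work(a, b):
    return a + b


if __name__ == '__main__':
    pool = Pool(processes=2)
    try:
        # Results arrive in input order regardless of which worker
        # finished first: prints [3, 7]
        print(pool.starmap(work, [(1, 2), (3, 4)]))
    finally:
        pool.close()  # no more tasks will be submitted
        pool.join()   # wait for the workers to exit
```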
@@ -296,9 +320,9 @@
             match = search.find_best(fragment['transcript'], start=start, end=end)
             match_start, match_end, sws_score, match_substitutions = match
             if sws_score > (n - 1) / (2 * n):
-                fragment['match_start'] = match_start
-                fragment['match_end'] = match_end
-                fragment['sws_score'] = sws_score
+                fragment['match-start'] = match_start
+                fragment['match-end'] = match_end
+                fragment['sws'] = sws_score
                 fragment['substitutions'] = match_substitutions
                 for f in split_match(fragments[0:index], start=start, end=match_start):
                     yield f
@@ -331,13 +355,13 @@
     for index in range(len(matched_fragments) + 1):
         if index > 0:
             a = matched_fragments[index - 1]
-            a_start, a_end = a['match_start'], a['match_end']
+            a_start, a_end = a['match-start'], a['match-end']
         else:
             a = None
             a_start = a_end = 0
         if index < len(matched_fragments):
             b = matched_fragments[index]
-            b_start, b_end = b['match_start'], b['match_end']
+            b_start, b_end = b['match-start'], b['match-end']
         else:
             b = None
             b_start = b_end = len(search.text)
@@ -362,34 +386,34 @@
                 a_best_end = b_best_start = overlap_start + best_index
                 if a:
-                    a['match_end'] = a_best_end
+                    a['match-end'] = a_best_end
                 if b:
-                    b['match_start'] = b_best_start
+                    b['match-start'] = b_best_start

     for fragment in fragments:
         index = fragment['index']
-        time_start = fragment['time-start']
-        time_length = fragment['time-length']
+        time_start = fragment['start']
+        time_end = fragment['end']
         fragment_transcript = fragment['transcript']
         result_fragment = {
-            'time-start': time_start,
-            'time-length': time_length
+            'start': time_start,
+            'end': time_end
         }
         sample_numbers = []
         if should_skip('tlen', index, sample_numbers, lambda: len(fragment_transcript)):
             continue
         if args.output_stt:
-            result_fragment['stt'] = fragment_transcript
-        if 'match_start' not in fragment:
+            result_fragment['transcript'] = fragment_transcript
+        if 'match-start' not in fragment:
             skip(index, 'No match for transcript')
             continue
-        match_start, match_end = fragment['match_start'], fragment['match_end']
+        match_start, match_end = fragment['match-start'], fragment['match-end']
         original_start = tc.get_original_offset(match_start)
         original_end = tc.get_original_offset(match_end)
         result_fragment['text-start'] = original_start
-        result_fragment['text-length'] = original_end - original_start
+        result_fragment['text-end'] = original_end
         if args.output_aligned_raw:
             result_fragment['aligned-raw'] = original_transcript[original_start:original_end]
@@ -400,7 +424,7 @@
         if args.output_aligned:
             result_fragment['aligned'] = fragment_matched
-        if should_skip('SWS', index, sample_numbers, lambda: 100 * fragment['sws_score']):
+        if should_skip('SWS', index, sample_numbers, lambda: 100 * fragment['sws']):
             continue
         if should_skip('WNG', index, sample_numbers,
@@ -438,7 +462,7 @@
                     args.audio,
                     'trim',
                     str(time_start / 1000.0),
-                    '='+str((time_start + time_length) / 1000.0)])
+                    '='+str(time_end / 1000.0)])
     with open(args.result, 'w') as result_file:
         result_file.write(json.dumps(result_fragments))
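With the key renames in this diff, an entry in the result file written above now carries absolute start/end values instead of start/length pairs. Roughly what one fragment looks like (values invented; transcript and aligned only appear when the corresponding --output-* flags are set):

```python
# Illustrative result fragment after the renames (old keys in comments):
result_fragment = {
    'start': 12340,                # ms; was 'time-start'
    'end': 15800,                  # ms; replaces 'time-length'
    'transcript': 'hello world',   # was 'stt'; requires --output-stt
    'text-start': 1024,            # character offset into the transcript
    'text-end': 1041,              # replaces 'text-length'
    'aligned': 'Hello, world!',    # requires --output-aligned
}
```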

wavTranscriber.py

@@ -3,7 +3,6 @@ import webrtcvad
 import logging
 import wavSplit
 from deepspeech import Model
-from timeit import default_timer as timer


 def load_model(models, alphabet, lm, trie):
@@ -24,17 +23,9 @@ def load_model(models, alphabet, lm, trie):
     LM_ALPHA = 1
     LM_BETA = 1.85
-    model_load_start = timer()
     ds = Model(models, N_FEATURES, N_CONTEXT, alphabet, BEAM_WIDTH)
-    model_load_end = timer() - model_load_start
-    logging.debug("Loaded model in %0.3fs." % (model_load_end))
-    lm_load_start = timer()
     ds.enableDecoderWithLM(alphabet, lm, trie, LM_ALPHA, LM_BETA)
-    lm_load_end = timer() - lm_load_start
-    logging.debug('Loaded language model in %0.3fs.' % (lm_load_end))
-    return ds, model_load_end, lm_load_end
+    return ds


 def stt(ds, audio, fs):
@@ -46,14 +37,9 @@ def stt(ds, audio, fs):
     :return: tuple (Inference result text, Inference time)
     """
     audio_length = len(audio) * (1 / 16000)
-    # Run DeepSpeech
-    logging.debug('Running inference...')
-    inference_start = timer()
     output = ds.stt(audio, fs)
-    inference_time = timer() - inference_start
-    logging.debug('Inference took %0.3fs for %0.3fs audio file.' % (inference_time, audio_length))
-    return output, inference_time
+    return output


 def resolve_models(dir_name):
@@ -63,16 +49,9 @@ def resolve_models(dir_name):
     :return: tuple containing each of the model files (pb, alphabet, lm and trie)
     """
     pb = glob.glob(dir_name + "/*.pb")[0]
-    logging.debug("Found Model: %s" % pb)
     alphabet = glob.glob(dir_name + "/alphabet.txt")[0]
-    logging.debug("Found Alphabet: %s" % alphabet)
     lm = glob.glob(dir_name + "/lm.binary")[0]
     trie = glob.glob(dir_name + "/trie")[0]
-    logging.debug("Found Language Model: %s" % lm)
-    logging.debug("Found Trie: %s" % trie)
     return pb, alphabet, lm, trie
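After this simplification, the wavTranscriber API returns bare values: load_model gives back just the model and stt just the transcript, with all timing instrumentation gone. A hedged end-to-end usage sketch, with the model directory assumed:

```python
import numpy as np
import wavTranscriber

# Locate and load the model files (directory name is an assumption).
pb, alphabet, lm, trie = wavTranscriber.resolve_models('models/en')
ds = wavTranscriber.load_model(pb, alphabet, lm, trie)

# One second of silence at 16 kHz, transcribed in-process.
audio = np.zeros(16000, dtype=np.int16)
print(wavTranscriber.stt(ds, audio, 16000))
```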