Mirror of https://github.com/mozilla/DSAlign.git
Multi-process transcription, .tlog files for transcripts, some light refactoring and reduced logging
This commit is contained in:
Parent
3031f35561
Commit
03f5b39d37
@@ -42,9 +42,9 @@ $ sudo apt-get install build-essential libboost-all-dev cmake zlib1g-dev libbz2-
 ```
 
 With all requirements fulfilled, there is a script for building and installing KenLM
-in the right location:
+and the required DeepSpeech tools in the right location:
 ```bash
-$ bin/buildkenlm.sh
+$ bin/lm-dependencies.sh
 ```
 
 If all went well, the alignment tool will find and use it to automatically create individual
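align/align.py (below) looks for the KenLM binaries under `dependencies/kenlm/build/bin`, so a minimal post-install sanity check, assuming that layout, could look like this:

```python
import os.path as path

# Path that align.py checks before building a language model (see the diff below).
kenlm_path = 'dependencies/kenlm/build/bin'

if not path.exists(kenlm_path):
    print('KenLM binaries not found in {} - run bin/lm-dependencies.sh first'.format(kenlm_path))
```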
120 align/align.py
@@ -7,11 +7,30 @@ import subprocess
 import os.path as path
 import numpy as np
 import wavTranscriber
+import multiprocessing
 from collections import Counter
 from search import FuzzySearch
 from text import Alphabet, TextCleaner, levenshtein, similarity
 from utils import enweight
 
+model = None
+sample_rate = 0
+worker_index = 0
+
+
+def init_stt(output_graph_path, alphabet_path, lm_path, trie_path, rate):
+    global model, sample_rate
+    sample_rate = rate
+    print('Process {}: Loaded models'.format(os.getpid()))
+    model = wavTranscriber.load_model(output_graph_path, alphabet_path, lm_path, trie_path)
+
+
+def stt(time_start, time_end, audio):
+    print('Process {}: Transcribing...'.format(os.getpid()))
+    transcript = wavTranscriber.stt(model, audio, sample_rate)
+    print('Process {}: {}'.format(os.getpid(), transcript))
+    return time_start, time_end, ' '.join(transcript.split())
+
 
 def main(args):
     parser = argparse.ArgumentParser(description='Force align speech data with a transcript.')
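The new `init_stt`/`stt` pair follows the standard `multiprocessing.Pool` initializer pattern: every worker process loads its own DeepSpeech model once into module-level globals, so each mapped call only has to ship the segment data. A minimal, self-contained sketch of the same pattern (a string stands in for the model so the example runs without DeepSpeech):

```python
import os
import multiprocessing

model = None  # set once per worker process by the initializer


def init_worker(model_name):
    global model
    # Stand-in for wavTranscriber.load_model(...): done once per process, not per task.
    model = 'model:{}'.format(model_name)
    print('Process {}: loaded {}'.format(os.getpid(), model))


def transcribe(time_start, time_end, audio):
    # Uses the process-global model; only the segment tuple is pickled per task.
    return time_start, time_end, '{} saw {} samples'.format(model, len(audio))


if __name__ == '__main__':
    pool = multiprocessing.Pool(processes=2, initializer=init_worker, initargs=('demo',))
    segments = [(0, 1000, [0] * 160), (1000, 2000, [0] * 320)]
    print(pool.starmap(transcribe, segments))
```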
@@ -86,7 +105,7 @@ def main(args):
 
     output_group = parser.add_argument_group(title='Output options')
     output_group.add_argument('--output-stt', action="store_true",
-                              help='Writes STT transcripts to result file')
+                              help='Writes STT transcripts to result file as attribute "transcript"')
     output_group.add_argument('--output-aligned', action="store_true",
                               help='Writes clean aligned original transcripts to result file')
     output_group.add_argument('--output-aligned-raw', action="store_true",
@@ -127,8 +146,6 @@ def main(args):
     logging.basicConfig(stream=sys.stderr, level=args.loglevel if args.loglevel else 20)
-    logging.debug("Start")
 
     fragments = []
-    fragments_cache_path = args.result + '.cache'
     model_dir = os.path.expanduser(args.stt_model_dir if args.stt_model_dir else 'models/en')
     logging.debug("Looking for model files in %s..." % model_dir)
     output_graph_path, alphabet_path, lang_lm_path, lang_trie_path = wavTranscriber.resolve_models(model_dir)
@@ -147,10 +164,11 @@ def main(args):
         with open(clean_text_path, 'w') as clean_text_file:
             clean_text_file.write(tc.clean_text)
 
-    if path.exists(fragments_cache_path):
-        logging.debug("Loading cached segment transcripts from %s..." % fragments_cache_path)
-        with open(fragments_cache_path, 'r') as result_file:
-            fragments = json.loads(result_file.read())
+    transcription_log = os.path.splitext(args.audio)[0] + '.tlog'
+    if path.exists(transcription_log):
+        logging.debug("Loading transcription log from %s..." % transcription_log)
+        with open(transcription_log, 'r') as transcriptions_file:
+            fragments = json.loads(transcriptions_file.read())
     else:
         kenlm_path = 'dependencies/kenlm/build/bin'
         if not path.exists(kenlm_path):
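The `.tlog` file introduced here sits next to the audio file (same basename, `.tlog` extension) and is plain JSON: a list of fragments, each with `start` and `end` in milliseconds plus the raw STT `transcript` (see the fragment dicts further down). A small illustration with made-up values:

```python
import json

# Hypothetical content of an example .tlog written after transcription.
fragments = [{'start': 0, 'end': 1530, 'transcript': 'hello world'},
             {'start': 1650, 'end': 3200, 'transcript': 'this is a test'}]

with open('recording.tlog', 'w') as f:
    f.write(json.dumps(fragments))

with open('recording.tlog', 'r') as f:
    print(json.loads(f.read())[0]['transcript'])
```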
@@ -194,7 +212,6 @@ def main(args):
 
         logging.debug('Loading acoustic model from "%s", alphabet from "%s" and language model from "%s"...' %
                       (output_graph_path, alphabet_path, lm_path))
-        model, _, _ = wavTranscriber.load_model(output_graph_path, alphabet_path, lm_path, trie_path)
 
         inference_time = 0.0
         offset = 0
@@ -203,36 +220,43 @@ def main(args):
         logging.debug("Transcribing VAD segments...")
         wave_file = args.audio
         aggressiveness = int(args.audio_vad_aggressiveness) if args.audio_vad_aggressiveness else 3
-        segments, sample_rate, audio_length = wavTranscriber.vad_segment_generator(wave_file, aggressiveness)
+        segments, rate, audio_length = wavTranscriber.vad_segment_generator(wave_file, aggressiveness)
 
-        for i, segment in enumerate(segments):
-            # Run DeepSpeech on the chunk that just completed VAD
-            segment_buffer, time_start, time_end = segment
-            time_length = time_end - time_start
-            if args.stt_min_duration and time_length < args.stt_min_duration:
-                logging.info('Fragment {}: Audio too short for STT'.format(i))
-                continue
-            if args.stt_max_duration and time_length > args.stt_max_duration:
-                logging.info('Fragment {}: Audio too long for STT'.format(i))
-                continue
-            logging.debug("Transcribing segment %002d (from %f to %f)..." % (i, time_start / 1000.0, time_end / 1000.0))
-            audio = np.frombuffer(segment_buffer, dtype=np.int16)
-            segment_transcript, segment_inference_time = wavTranscriber.stt(model, audio, sample_rate)
-            segment_transcript = ' '.join(segment_transcript.split())
-            inference_time += segment_inference_time
+        pool = multiprocessing.Pool(initializer=init_stt,
+                                    initargs=(output_graph_path, alphabet_path, lm_path, trie_path, rate),
+                                    processes=None)
+
+        def pre_filter():
+            for i, segment in enumerate(segments):
+                segment_buffer, time_start, time_end = segment
+                time_length = time_end - time_start
+                if args.stt_min_duration and time_length < args.stt_min_duration:
+                    logging.info('Fragment {}: Audio too short for STT'.format(i))
+                    continue
+                if args.stt_max_duration and time_length > args.stt_max_duration:
+                    logging.info('Fragment {}: Audio too long for STT'.format(i))
+                    continue
+                #logging.debug("Transcribing segment %002d (from %f to %f)..." % (i, time_start / 1000.0, time_end / 1000.0))
+                yield (time_start, time_end, np.frombuffer(segment_buffer, dtype=np.int16))
+
+        samples = list(pre_filter())[10:60]
+
+        transcripts = pool.starmap(stt, samples)
+
+        fragments = []
+        for time_start, time_end, segment_transcript in transcripts:
             if segment_transcript is None:
                 logging.debug("Segment %002d empty" % i)
                 continue
             fragments.append({
-                'time-start': time_start,
-                'time-length': time_length,
-                'transcript': segment_transcript
+                'start': time_start,
+                'end': time_end,
+                'transcript': segment_transcript
             })
-            offset += len(segment_transcript)
 
-        logging.debug("Writing segment transcripts to cache file %s..." % fragments_cache_path)
-        with open(fragments_cache_path, 'w') as result_file:
-            result_file.write(json.dumps(fragments))
+        logging.debug("Writing transcription log to file %s..." % transcription_log)
+        with open(transcription_log, 'w') as transcriptions_file:
+            transcriptions_file.write(json.dumps(fragments))
 
     search = FuzzySearch(tc.clean_text,
                          max_candidates=args.align_max_candidates,
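`pre_filter` is a generator that drops segments outside the configured duration bounds and yields `(time_start, time_end, samples)` tuples; `pool.starmap` then unpacks each tuple into the worker-side `stt`. Note that `list(pre_filter())[10:60]` limits this run to segments 10 through 59. A reduced sketch of the same fan-out, independent of VAD and DeepSpeech:

```python
import multiprocessing


def work(start, end, payload):
    # Stand-in for the worker-side stt(): receives exactly what the generator yields.
    return start, end, sum(payload)


def duration_filter(segments, min_len=1, max_len=10):
    for start, end, payload in segments:
        if end - start < min_len or end - start > max_len:
            continue  # skip, like the stt_min_duration / stt_max_duration checks
        yield start, end, payload


if __name__ == '__main__':
    segments = [(0, 5, [1, 2]), (5, 6, [3]), (6, 30, [4, 5])]
    pool = multiprocessing.Pool(processes=2)
    print(pool.starmap(work, list(duration_filter(segments))))
```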
@@ -296,9 +320,9 @@ def main(args):
         match = search.find_best(fragment['transcript'], start=start, end=end)
         match_start, match_end, sws_score, match_substitutions = match
         if sws_score > (n - 1) / (2 * n):
-            fragment['match_start'] = match_start
-            fragment['match_end'] = match_end
-            fragment['sws_score'] = sws_score
+            fragment['match-start'] = match_start
+            fragment['match-end'] = match_end
+            fragment['sws'] = sws_score
             fragment['substitutions'] = match_substitutions
             for f in split_match(fragments[0:index], start=start, end=match_start):
                 yield f
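The acceptance test `sws_score > (n - 1) / (2 * n)` scales its threshold with `n` (defined outside this hunk): the bound is 0.25 at n = 2 and approaches 0.5 as n grows. A quick look at the values:

```python
# Threshold used in the acceptance test above, for a few values of n.
for n in (2, 3, 4, 5, 10):
    print(n, (n - 1) / (2 * n))  # 0.25, 0.333..., 0.375, 0.4, 0.45
```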
@@ -331,13 +355,13 @@ def main(args):
     for index in range(len(matched_fragments) + 1):
         if index > 0:
             a = matched_fragments[index - 1]
-            a_start, a_end = a['match_start'], a['match_end']
+            a_start, a_end = a['match-start'], a['match-end']
         else:
             a = None
             a_start = a_end = 0
         if index < len(matched_fragments):
             b = matched_fragments[index]
-            b_start, b_end = b['match_start'], b['match_end']
+            b_start, b_end = b['match-start'], b['match-end']
         else:
             b = None
             b_start = b_end = len(search.text)
@@ -362,34 +386,34 @@ def main(args):
             a_best_end = b_best_start = overlap_start + best_index
 
             if a:
-                a['match_end'] = a_best_end
+                a['match-end'] = a_best_end
             if b:
-                b['match_start'] = b_best_start
+                b['match-start'] = b_best_start
 
     for fragment in fragments:
         index = fragment['index']
-        time_start = fragment['time-start']
-        time_length = fragment['time-length']
+        time_start = fragment['start']
+        time_end = fragment['end']
         fragment_transcript = fragment['transcript']
         result_fragment = {
-            'time-start': time_start,
-            'time-length': time_length
+            'start': time_start,
+            'end': time_end
         }
         sample_numbers = []
 
         if should_skip('tlen', index, sample_numbers, lambda: len(fragment_transcript)):
             continue
         if args.output_stt:
-            result_fragment['stt'] = fragment_transcript
+            result_fragment['transcript'] = fragment_transcript
 
-        if 'match_start' not in fragment:
+        if 'match-start' not in fragment:
             skip(index, 'No match for transcript')
             continue
-        match_start, match_end = fragment['match_start'], fragment['match_end']
+        match_start, match_end = fragment['match-start'], fragment['match-end']
         original_start = tc.get_original_offset(match_start)
         original_end = tc.get_original_offset(match_end)
         result_fragment['text-start'] = original_start
-        result_fragment['text-length'] = original_end - original_start
+        result_fragment['text-end'] = original_end
 
         if args.output_aligned_raw:
             result_fragment['aligned-raw'] = original_transcript[original_start:original_end]
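After this change a result fragment is keyed by `start`/`end` (milliseconds of audio) and `text-start`/`text-end` (character offsets into the original transcript); the raw STT output is stored under `transcript` when `--output-stt` is given. A hypothetical entry, values invented for illustration:

```python
# Hypothetical aligned fragment as written to the result file after this commit.
result_fragment = {
    'start': 12000,                        # ms into the audio
    'end': 15200,
    'text-start': 340,                     # character offsets into the original transcript
    'text-end': 402,
    'transcript': 'the quick brown fox',   # raw STT output, only with --output-stt
    'aligned': 'The quick brown fox',      # cleaned original text, only with --output-aligned
}
```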
@@ -400,7 +424,7 @@ def main(args):
         if args.output_aligned:
             result_fragment['aligned'] = fragment_matched
 
-        if should_skip('SWS', index, sample_numbers, lambda: 100 * fragment['sws_score']):
+        if should_skip('SWS', index, sample_numbers, lambda: 100 * fragment['sws']):
             continue
 
         if should_skip('WNG', index, sample_numbers,
@@ -438,7 +462,7 @@ def main(args):
                                    args.audio,
                                    'trim',
                                    str(time_start / 1000.0),
-                                   '='+str((time_start + time_length) / 1000.0)])
+                                   '='+str(time_end / 1000.0)])
     with open(args.result, 'w') as result_file:
         result_file.write(json.dumps(result_fragments))
 
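In sox's `trim` effect a position prefixed with `=` is absolute (measured from the start of the audio), so the call now cuts from `time_start` to `time_end`, both converted from milliseconds to seconds. The opening of the subprocess call sits outside this hunk; a standalone command using the same trim arguments might look like:

```python
import subprocess

time_start, time_end = 12000, 15200  # milliseconds, as stored in the fragment

# Hypothetical input/output paths; only the trim arguments mirror the diff above.
subprocess.check_call(['sox', 'audio.wav', 'fragment.wav',
                       'trim',
                       str(time_start / 1000.0),
                       '=' + str(time_end / 1000.0)])
```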
align/wavTranscriber.py
@@ -3,7 +3,6 @@ import webrtcvad
 import logging
 import wavSplit
 from deepspeech import Model
-from timeit import default_timer as timer
 
 
 def load_model(models, alphabet, lm, trie):
@@ -24,17 +23,9 @@ def load_model(models, alphabet, lm, trie):
     LM_ALPHA = 1
     LM_BETA = 1.85
 
-    model_load_start = timer()
     ds = Model(models, N_FEATURES, N_CONTEXT, alphabet, BEAM_WIDTH)
-    model_load_end = timer() - model_load_start
-    logging.debug("Loaded model in %0.3fs." % (model_load_end))
-
-    lm_load_start = timer()
     ds.enableDecoderWithLM(alphabet, lm, trie, LM_ALPHA, LM_BETA)
-    lm_load_end = timer() - lm_load_start
-    logging.debug('Loaded language model in %0.3fs.' % (lm_load_end))
-
-    return ds, model_load_end, lm_load_end
+    return ds
 
 
 def stt(ds, audio, fs):
@@ -46,14 +37,9 @@ def stt(ds, audio, fs):
     :return: tuple (Inference result text, Inference time)
     """
-    audio_length = len(audio) * (1 / 16000)
 
     # Run DeepSpeech
-    logging.debug('Running inference...')
-    inference_start = timer()
     output = ds.stt(audio, fs)
-    inference_time = timer() - inference_start
-    logging.debug('Inference took %0.3fs for %0.3fs audio file.' % (inference_time, audio_length))
-    return output, inference_time
+    return output
 
 
 def resolve_models(dir_name):
@@ -63,16 +49,9 @@ def resolve_models(dir_name):
     :return: tuple containing each of the model files (pb, alphabet, lm and trie)
     """
     pb = glob.glob(dir_name + "/*.pb")[0]
-    logging.debug("Found Model: %s" % pb)
-
     alphabet = glob.glob(dir_name + "/alphabet.txt")[0]
-    logging.debug("Found Alphabet: %s" % alphabet)
-
     lm = glob.glob(dir_name + "/lm.binary")[0]
     trie = glob.glob(dir_name + "/trie")[0]
-    logging.debug("Found Language Model: %s" % lm)
-    logging.debug("Found Trie: %s" % trie)
-
     return pb, alphabet, lm, trie
 
 
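With timing and most logging stripped out, the `wavTranscriber` API used by the align workers reduces to three calls: `resolve_models` to locate the files, `load_model` returning just the `Model` instance, and `stt` returning just the transcript text. A sketch of how they compose now (model directory and audio are placeholders):

```python
import numpy as np
import wavTranscriber

# Placeholder model directory containing *.pb, alphabet.txt, lm.binary and trie.
graph, alphabet, lm, trie = wavTranscriber.resolve_models('models/en')
model = wavTranscriber.load_model(graph, alphabet, lm, trie)

# One second of silence as dummy 16 kHz mono int16 input.
audio = np.zeros(16000, dtype=np.int16)
print(wavTranscriber.stt(model, audio, 16000))
```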