Mirror of https://github.com/mozilla/DSAlign.git

Light refactoring to integrate DeepSpeech audio helpers. Additional VAD parameters.

Parent: d17489c29e
Commit: fba9c8971d
align/align.py (453 changes)

@@ -3,40 +3,40 @@ import sys
 import json
 import logging
 import argparse
+import deepspeech
 import subprocess
 import os.path as path
 import numpy as np
 import textdistance
-import wavSplit
-import wavTranscriber
 import multiprocessing
 from collections import Counter
 from search import FuzzySearch
+from glob import glob
 from tqdm import tqdm
 from text import Alphabet, TextCleaner, levenshtein, similarity
 from utils import enweight
+from audio import DEFAULT_RATE, read_frames_from_file, vad_split
 
-algos = ['WNG', 'jaro_winkler', 'editex', 'levenshtein', 'mra', 'hamming']
-sim_desc = 'From 0.0 (not equal at all) to 100.0 (totally equal)'
-named_numbers = {
+BEAM_WIDTH = 500
+LM_ALPHA = 1
+LM_BETA = 1.85
+
+ALGORITHMS = ['WNG', 'jaro_winkler', 'editex', 'levenshtein', 'mra', 'hamming']
+SIM_DESC = 'From 0.0 (not equal at all) to 100.0 (totally equal)'
+NAMED_NUMBERS = {
     'tlen': ('transcript length', int, None),
     'mlen': ('match length', int, None),
     'SWS': ('Smith-Waterman score', float, 'From 0.0 (not equal at all) to 100.0+ (pretty equal)'),
-    'WNG': ('weighted N-gram similarity', float, sim_desc),
-    'jaro_winkler': ('Jaro-Winkler similarity', float, sim_desc),
-    'editex': ('Editex similarity', float, sim_desc),
-    'levenshtein': ('Levenshtein similarity', float, sim_desc),
-    'mra': ('MRA similarity', float, sim_desc),
-    'hamming': ('Hamming similarity', float, sim_desc),
+    'WNG': ('weighted N-gram similarity', float, SIM_DESC),
+    'jaro_winkler': ('Jaro-Winkler similarity', float, SIM_DESC),
+    'editex': ('Editex similarity', float, SIM_DESC),
+    'levenshtein': ('Levenshtein similarity', float, SIM_DESC),
+    'mra': ('MRA similarity', float, SIM_DESC),
+    'hamming': ('Hamming similarity', float, SIM_DESC),
     'CER': ('character error rate', float, 'From 0.0 (no different words) to 100.0+ (total miss)'),
     'WER': ('word error rate', float, 'From 0.0 (no wrong characters) to 100.0+ (total miss)')
 }
 
 args = None
 model = None
-sample_rate = 0
 alphabet = None
 
 
 def fail(message, code=1):
     logging.fatal(message)
@@ -61,27 +61,23 @@ def read_script(script_path):
     return tc
 
 
-def init_stt(output_graph_path, alphabet_path, lm_path, trie_path, rate):
-    global model, sample_rate
-    sample_rate = rate
-    model = None
+def init_stt(output_graph_path, lm_path, trie_path):
+    global model
+    model = deepspeech.Model(output_graph_path, BEAM_WIDTH)
+    model.enableDecoderWithLM(lm_path, trie_path, LM_ALPHA, LM_BETA)
     logging.debug('Process {}: Loaded models'.format(os.getpid()))
-    model = wavTranscriber.load_model(output_graph_path, alphabet_path, lm_path, trie_path)
 
 
 def stt(sample):
     time_start, time_end, audio = sample
     logging.debug('Process {}: Transcribing...'.format(os.getpid()))
-    transcript = wavTranscriber.stt(model, audio, sample_rate)
+    transcript = model.stt(audio)
     logging.debug('Process {}: {}'.format(os.getpid(), transcript))
     return time_start, time_end, ' '.join(transcript.split())
 
 
-def init_align(w_args, w_alphabet):
-    global args, alphabet
-    args = w_args
-    alphabet = w_alphabet
-
-
 def align(triple):
     tlog, script, aligned = triple
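To make the refactoring concrete: STT worker processes now talk to the deepspeech package directly instead of going through the removed wavTranscriber wrapper. A minimal stand-alone sketch of the same decoder setup, assuming the deepspeech Python API as used in this commit and placeholder model paths (not from the repo):

    import deepspeech
    import numpy as np

    model = deepspeech.Model('models/en/output_graph.pb', 500)   # graph path, BEAM_WIDTH
    model.enableDecoderWithLM('models/en/lm.binary', 'models/en/trie',
                              1, 1.85)                           # LM_ALPHA, LM_BETA
    audio = np.zeros(16000, dtype=np.int16)  # one second of silence as dummy int16 input
    print(model.stt(audio))                  # transcript string (empty for silence)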
@@ -166,7 +162,7 @@ def align(triple):
                                                max_ngram_size=args.align_wng_max_size,
                                                size_factor=args.align_wng_size_factor,
                                                position_factor=args.align_wng_position_factor)
-        elif algo in algos:
+        elif algo in ALGORITHMS:
             algo_impl = similarity_algos[algo] = getattr(textdistance, algo).normalized_similarity
         else:
             logging.fatal('Unknown similarity metric "{}"'.format(algo))
@@ -257,7 +253,7 @@ def align(triple):
             show.insert(0, '{}: {:.2f}'.format(number_key, val))
         if should_output:
             fragment[kl] = val
-        reason_base = '{} ({})'.format(named_numbers[number_key][0], number_key)
+        reason_base = '{} ({})'.format(NAMED_NUMBERS[number_key][0], number_key)
         reason = None
         if min_val and val < min_val:
             reason = reason_base + ' too low'
@@ -321,7 +317,7 @@ def align(triple):
             continue
 
         should_skip = False
-        for algo in algos:
+        for algo in ALGORITHMS:
             should_skip = should_skip or apply_number(algo, index, result_fragment, sample_numbers,
                                                       lambda: 100 * phrase_similarity(algo,
                                                                                       fragment_matched,
@@ -361,7 +357,194 @@ def align(triple):
 
 
 def main():
+    global args, alphabet
+    # Debug helpers
+    logging.basicConfig(stream=sys.stdout, level=args.loglevel if args.loglevel else 20)
+
+    def progress(iter, **kwargs):
+        return iter if args.no_progress else tqdm(iter, **kwargs)
+
+    def resolve(base_path, spec_path):
+        if spec_path is None:
+            return None
+        if not path.isabs(spec_path):
+            spec_path = path.join(base_path, spec_path)
+        return spec_path
+
+    def exists(file_path):
+        if file_path is None:
+            return False
+        return os.path.isfile(file_path)
+
+    to_prepare = []
+
+    def enqueue_or_fail(audio, tlog, script, aligned, prefix=''):
+        if exists(aligned) and not args.force:
+            fail(prefix + 'Alignment file "{}" already existing - use --force to overwrite'.format(aligned))
+        if tlog is None:
+            if args.ignore_missing:
+                return
+            fail(prefix + 'Missing transcription log path')
+        if not exists(audio) and not exists(tlog):
+            if args.ignore_missing:
+                return
+            fail(prefix + 'Both audio file "{}" and transcription log "{}" are missing'.format(audio, tlog))
+        if not exists(script):
+            if args.ignore_missing:
+                return
+            fail(prefix + 'Missing script "{}"'.format(script))
+        to_prepare.append((audio, tlog, script, aligned))
+
+    if (args.audio or args.tlog) and args.script and args.aligned and not args.catalog:
+        enqueue_or_fail(args.audio, args.tlog, args.script, args.aligned)
+    elif args.catalog:
+        if not exists(args.catalog):
+            fail('Unable to load catalog file "{}"'.format(args.catalog))
+        catalog = path.abspath(args.catalog)
+        catalog_dir = path.dirname(catalog)
+        with open(catalog, 'r') as catalog_file:
+            catalog_entries = json.load(catalog_file)
+        for entry in progress(catalog_entries, desc='Reading catalog'):
+            enqueue_or_fail(resolve(catalog_dir, entry['audio']),
+                            resolve(catalog_dir, entry['tlog']),
+                            resolve(catalog_dir, entry['script']),
+                            resolve(catalog_dir, entry['aligned']),
+                            prefix='Problem loading catalog "{}" - '.format(catalog))
+    else:
+        fail('You have to either specify a combination of "--audio/--tlog,--script,--aligned" or "--catalog"')
+
+    logging.debug('Start')
+
+    to_align = []
+    output_graph_path = None
+    for audio_path, tlog_path, script_path, aligned_path in to_prepare:
+        if not exists(tlog_path):
+            if output_graph_path is None:
+                logging.debug('Looking for model files in "{}"...'.format(model_dir))
+                output_graph_path = glob(model_dir + "/output_graph.pb")[0]
+                lang_lm_path = glob(model_dir + "/lm.binary")[0]
+                lang_trie_path = glob(model_dir + "/trie")[0]
+            kenlm_path = 'dependencies/kenlm/build/bin'
+            if not path.exists(kenlm_path):
+                kenlm_path = None
+            deepspeech_path = 'dependencies/deepspeech'
+            if not path.exists(deepspeech_path):
+                deepspeech_path = None
+            if kenlm_path and deepspeech_path and not args.stt_no_own_lm:
+                tc = read_script(script_path)
+                if not tc.clean_text.strip():
+                    logging.error('Cleaned transcript is empty for {}'.format(path.basename(script_path)))
+                    continue
+                clean_text_path = script_path + '.clean'
+                with open(clean_text_path, 'w') as clean_text_file:
+                    clean_text_file.write(tc.clean_text)
+
+                arpa_path = script_path + '.arpa'
+                if not path.exists(arpa_path):
+                    subprocess.check_call([
+                        kenlm_path + '/lmplz',
+                        '--discount_fallback',
+                        '--text',
+                        clean_text_path,
+                        '--arpa',
+                        arpa_path,
+                        '--o',
+                        '5'
+                    ])
+
+                lm_path = script_path + '.lm'
+                if not path.exists(lm_path):
+                    subprocess.check_call([
+                        kenlm_path + '/build_binary',
+                        '-s',
+                        arpa_path,
+                        lm_path
+                    ])
+
+                trie_path = script_path + '.trie'
+                if not path.exists(trie_path):
+                    subprocess.check_call([
+                        deepspeech_path + '/generate_trie',
+                        alphabet_path,
+                        lm_path,
+                        trie_path
+                    ])
+            else:
+                lm_path = lang_lm_path
+                trie_path = lang_trie_path
+
+            logging.debug('Loading acoustic model from "{}", alphabet from "{}", trie from "{}" and language model from "{}"...'
+                          .format(output_graph_path, alphabet_path, trie_path, lm_path))
+
+            # Run VAD on the input file
+            logging.debug('Transcribing VAD segments...')
+            frames = read_frames_from_file(audio_path, model_format, args.audio_vad_frame_length)
+            segments = vad_split(frames,
+                                 model_format,
+                                 num_padding_frames=args.audio_vad_padding,
+                                 threshold=args.audio_vad_threshold,
+                                 aggressiveness=args.audio_vad_aggressiveness)
+
+            def pre_filter():
+                for i, segment in enumerate(segments):
+                    segment_buffer, time_start, time_end = segment
+                    time_length = time_end - time_start
+                    if args.stt_min_duration and time_length < args.stt_min_duration:
+                        logging.info('Fragment {}: Audio too short for STT'.format(i))
+                        continue
+                    if args.stt_max_duration and time_length > args.stt_max_duration:
+                        logging.info('Fragment {}: Audio too long for STT'.format(i))
+                        continue
+                    yield (time_start, time_end, np.frombuffer(segment_buffer, dtype=np.int16))
+
+            samples = list(progress(pre_filter(), desc='VAD splitting'))
+
+            pool = multiprocessing.Pool(initializer=init_stt,
+                                        initargs=(output_graph_path, lm_path, trie_path),
+                                        processes=args.stt_workers)
+            transcripts = progress(pool.imap(stt, samples), desc='Transcribing', total=len(samples))
+
+            fragments = []
+            for time_start, time_end, segment_transcript in transcripts:
+                if segment_transcript is None:
+                    continue
+                fragments.append({
+                    'start': time_start,
+                    'end': time_end,
+                    'transcript': segment_transcript
+                })
+            logging.debug('Excluded {} empty transcripts'.format(len(transcripts) - len(fragments)))
+
+            logging.debug('Writing transcription log to file "{}"...'.format(tlog_path))
+            with open(tlog_path, 'w') as tlog_file:
+                tlog_file.write(json.dumps(fragments, indent=4 if args.output_pretty else None))
+        if not path.isfile(tlog_path):
+            fail('Problem loading transcript from "{}"'.format(tlog_path))
+        to_align.append((tlog_path, script_path, aligned_path))
+
+    total_fragments = 0
+    dropped_fragments = 0
+    reasons = Counter()
+
+    index = 0
+    pool = multiprocessing.Pool(processes=args.align_workers)
+    for aligned_file, file_total_fragments, file_dropped_fragments, file_reasons in \
+            progress(pool.imap_unordered(align, to_align), desc='Aligning', total=len(to_align)):
+        if args.no_progress:
+            index += 1
+            logging.info('Aligned file {} of {} - wrote results to "{}"'.format(index, len(to_align), aligned_file))
+        total_fragments += file_total_fragments
+        dropped_fragments += file_dropped_fragments
+        reasons += file_reasons
+
+    logging.info('Aligned {} fragments'.format(total_fragments))
+    if total_fragments > 0 and dropped_fragments > 0:
+        logging.info('Dropped {} fragments {:0.2f}%:'.format(dropped_fragments,
+                                                             dropped_fragments * 100.0 / total_fragments))
+        for key, number in reasons.most_common():
+            logging.info(' - {}: {}'.format(key, number))
+
+
+def parse_args():
     parser = argparse.ArgumentParser(description='Force align speech data with a transcript.')
 
     parser.add_argument('--audio', type=str,
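For readers tracing enqueue_or_fail above: --catalog is expected to point at a JSON array whose entries carry the four keys read in the loop, with relative paths resolved against the catalog's own directory. A hypothetical catalog file (all paths invented for illustration):

    [
        {
            "audio": "audio/chapter1.wav",
            "tlog": "transcripts/chapter1.tlog",
            "script": "scripts/chapter1.txt",
            "aligned": "aligned/chapter1.aligned"
        }
    ]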
@@ -395,10 +578,20 @@ def main():
                         help='Path to an alphabet file (overriding the one from --stt-model-dir)')
 
     audio_group = parser.add_argument_group(title='Audio pre-processing options')
-    audio_group.add_argument('--audio-vad-aggressiveness', type=int, choices=range(4), required=False,
-                             help='Determines how aggressive filtering out non-speech is (default: 3)')
+    audio_group.add_argument('--audio-vad-aggressiveness', type=int, choices=range(4), default=3,
+                             help='Aggressiveness of voice activity detection in a frame (default: 3)')
+    audio_group.add_argument('--audio-vad-padding', type=int, default=10,
+                             help='Number of padding audio frames in VAD ring-buffer')
+    audio_group.add_argument('--audio-vad-threshold', type=float, default=0.5,
+                             help='VAD ring-buffer threshold for voiced frames '
+                                  '(e.g. 0.5 -> 50% of the ring-buffer frames have to be voiced '
+                                  'for triggering a split)')
+    audio_group.add_argument('--audio-vad-frame-length', choices=[10, 20, 30], default=30,
+                             help='VAD audio frame length in ms (10, 20 or 30)')
 
     stt_group = parser.add_argument_group(title='STT options')
+    stt_group.add_argument('--stt-model-rate', type=int, default=DEFAULT_RATE,
+                           help='Supported sample rate of the acoustic model')
     stt_group.add_argument('--stt-model-dir', required=False,
                            help='Path to a directory with output_graph, lm, trie and (optional) alphabet file ' +
                                 '(default: "data/en"')
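Taken together, the new audio group makes the formerly hard-coded webrtcvad parameters tunable from the command line. A hypothetical invocation exercising them (file and model paths are placeholders, not from the commit):

    python align/align.py \
        --audio data/speech.wav --script data/speech.txt --aligned data/speech.aligned \
        --stt-model-dir models/en \
        --audio-vad-aggressiveness 2 --audio-vad-padding 20 \
        --audio-vad-threshold 0.6 --audio-vad-frame-length 20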
@@ -467,8 +660,8 @@ def main():
     output_group.add_argument('--output-pretty', action="store_true",
                               help='Writes indented JSON output"')
 
-    for short in named_numbers.keys():
-        long, atype, desc = named_numbers[short]
+    for short in NAMED_NUMBERS.keys():
+        long, atype, desc = NAMED_NUMBERS[short]
         desc = (' - value range: ' + desc) if desc else ''
         output_group.add_argument('--output-' + short.lower(), action="store_true",
                                   help='Writes {} ({}) to output'.format(long, short))
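Since switch names are derived from the NAMED_NUMBERS keys, this loop (together with the min/max loop in the next hunk, which presumably iterates 'Min' and 'Max' given its '{}imum' format string) expands to switches such as --output-sws, --output-wer, --output-min-cer. A short sketch reproducing the derivation:

    for short in ['tlen', 'mlen', 'SWS', 'WNG', 'jaro_winkler', 'editex',
                  'levenshtein', 'mra', 'hamming', 'CER', 'WER']:
        print('--output-' + short.lower())                               # plain value output
        for extreme in ['Min', 'Max']:
            print('--output-' + extreme.lower() + '-' + short.lower())   # filtering thresholds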
@@ -476,69 +669,14 @@ def main():
         output_group.add_argument('--output-' + extreme.lower() + '-' + short.lower(), type=atype, required=False,
                                   help='{}imum {} ({}) the STT transcript of the audio '
                                        'has to have when compared with the original text{}'
-                                  .format(extreme, long, short, desc))
+                                       .format(extreme, long, short, desc))
 
-    args = parser.parse_args()
+    return parser.parse_args()
 
-    # Debug helpers
-    logging.basicConfig(stream=sys.stdout, level=args.loglevel if args.loglevel else 20)
-
-    def progress(iter, **kwargs):
-        return iter if args.no_progress else tqdm(iter, **kwargs)
-
-    def resolve(base_path, spec_path):
-        if spec_path is None:
-            return None
-        if not path.isabs(spec_path):
-            spec_path = path.join(base_path, spec_path)
-        return spec_path
-
-    def exists(file_path):
-        if file_path is None:
-            return False
-        return os.path.isfile(file_path)
-
-    to_prepare = []
-
-    def enqueue_or_fail(audio, tlog, script, aligned, prefix=''):
-        if exists(aligned) and not args.force:
-            fail(prefix + 'Alignment file "{}" already existing - use --force to overwrite'.format(aligned))
-        if tlog is None:
-            if args.ignore_missing:
-                return
-            fail(prefix + 'Missing transcription log path')
-        if not exists(audio) and not exists(tlog):
-            if args.ignore_missing:
-                return
-            fail(prefix + 'Both audio file "{}" and transcription log "{}" are missing'.format(audio, tlog))
-        if not exists(script):
-            if args.ignore_missing:
-                return
-            fail(prefix + 'Missing script "{}"'.format(script))
-        to_prepare.append((audio, tlog, script, aligned))
-
-    if (args.audio or args.tlog) and args.script and args.aligned and not args.catalog:
-        enqueue_or_fail(args.audio, args.tlog, args.script, args.aligned)
-    elif args.catalog:
-        if not exists(args.catalog):
-            fail('Unable to load catalog file "{}"'.format(args.catalog))
-        catalog = path.abspath(args.catalog)
-        catalog_dir = path.dirname(catalog)
-        with open(catalog, 'r') as catalog_file:
-            catalog_entries = json.load(catalog_file)
-        for entry in progress(catalog_entries, desc='Reading catalog'):
-            enqueue_or_fail(resolve(catalog_dir, entry['audio']),
-                            resolve(catalog_dir, entry['tlog']),
-                            resolve(catalog_dir, entry['script']),
-                            resolve(catalog_dir, entry['aligned']),
-                            prefix='Problem loading catalog "{}" - '.format(catalog))
-    else:
-        fail('You have to either specify a combination of "--audio/--tlog,--script,--aligned" or "--catalog"')
-
-    logging.debug('Start')
 
+if __name__ == '__main__':
+    args = parse_args()
+    model_dir = os.path.expanduser(args.stt_model_dir if args.stt_model_dir else 'models/en')
+
     if args.alphabet is not None:
         alphabet_path = args.alphabet
     else:
@@ -547,130 +685,5 @@ def main():
         fail('Found no alphabet file')
     logging.debug('Loading alphabet from "{}"...'.format(alphabet_path))
     alphabet = Alphabet(alphabet_path)
-
-    to_align = []
-    output_graph_path = None
-    for audio, tlog, script, aligned in to_prepare:
-        if not exists(tlog):
-            if output_graph_path is None:
-                logging.debug('Looking for model files in "{}"...'.format(model_dir))
-                output_graph_path, lang_lm_path, lang_trie_path = wavTranscriber.resolve_models(model_dir)
-            kenlm_path = 'dependencies/kenlm/build/bin'
-            if not path.exists(kenlm_path):
-                kenlm_path = None
-            deepspeech_path = 'dependencies/deepspeech'
-            if not path.exists(deepspeech_path):
-                deepspeech_path = None
-            if kenlm_path and deepspeech_path and not args.stt_no_own_lm:
-                tc = read_script(script)
-                if not tc.clean_text.strip():
-                    logging.error('Cleaned transcript is empty for {}'.format(path.basename(script)))
-                    continue
-                clean_text_path = script + '.clean'
-                with open(clean_text_path, 'w') as clean_text_file:
-                    clean_text_file.write(tc.clean_text)
-
-                arpa_path = script + '.arpa'
-                if not path.exists(arpa_path):
-                    subprocess.check_call([
-                        kenlm_path + '/lmplz',
-                        '--discount_fallback',
-                        '--text',
-                        clean_text_path,
-                        '--arpa',
-                        arpa_path,
-                        '--o',
-                        '5'
-                    ])
-
-                lm_path = script + '.lm'
-                if not path.exists(lm_path):
-                    subprocess.check_call([
-                        kenlm_path + '/build_binary',
-                        '-s',
-                        arpa_path,
-                        lm_path
-                    ])
-
-                trie_path = script + '.trie'
-                if not path.exists(trie_path):
-                    subprocess.check_call([
-                        deepspeech_path + '/generate_trie',
-                        alphabet_path,
-                        lm_path,
-                        trie_path
-                    ])
-            else:
-                lm_path = lang_lm_path
-                trie_path = lang_trie_path
-
-            logging.debug('Loading acoustic model from "{}", alphabet from "{}", trie from "{}" and language model from "{}"...'
-                          .format(output_graph_path, alphabet_path, trie_path, lm_path))
-
-            # Run VAD on the input file
-            logging.debug('Transcribing VAD segments...')
-            aggressiveness = int(args.audio_vad_aggressiveness) if args.audio_vad_aggressiveness else 3
-            segments, rate, audio_length = wavSplit.vad_segment_generator(audio, aggressiveness)
-
-            def pre_filter():
-                for i, segment in enumerate(segments):
-                    segment_buffer, time_start, time_end = segment
-                    time_length = time_end - time_start
-                    if args.stt_min_duration and time_length < args.stt_min_duration:
-                        logging.info('Fragment {}: Audio too short for STT'.format(i))
-                        continue
-                    if args.stt_max_duration and time_length > args.stt_max_duration:
-                        logging.info('Fragment {}: Audio too long for STT'.format(i))
-                        continue
-                    yield (time_start, time_end, np.frombuffer(segment_buffer, dtype=np.int16))
-
-            samples = list(progress(pre_filter(), desc='VAD splitting'))
-
-            pool = multiprocessing.Pool(initializer=init_stt,
-                                        initargs=(output_graph_path, alphabet_path, lm_path, trie_path, rate),
-                                        processes=args.stt_workers)
-            transcripts = progress(pool.imap(stt, samples), desc='Transcribing', total=len(samples))
-
-            fragments = []
-            for time_start, time_end, segment_transcript in transcripts:
-                if segment_transcript is None:
-                    continue
-                fragments.append({
-                    'start': time_start,
-                    'end': time_end,
-                    'transcript': segment_transcript
-                })
-            logging.debug('Excluded {} empty transcripts'.format(len(transcripts) - len(fragments)))
-
-            logging.debug('Writing transcription log to file "{}"...'.format(tlog))
-            with open(tlog, 'w') as tlog_file:
-                tlog_file.write(json.dumps(fragments, indent=4 if args.output_pretty else None))
-        if not path.isfile(tlog):
-            fail('Problem loading transcript from "{}"'.format(tlog))
-        to_align.append((tlog, script, aligned))
-
-    total_fragments = 0
-    dropped_fragments = 0
-    reasons = Counter()
-
-    index = 0
-    pool = multiprocessing.Pool(initializer=init_align, initargs=(args, alphabet), processes=args.align_workers)
-    for aligned_file, file_total_fragments, file_dropped_fragments, file_reasons in \
-            progress(pool.imap_unordered(align, to_align), desc='Aligning', total=len(to_align)):
-        if args.no_progress:
-            index += 1
-            logging.info('Aligned file {} of {} - wrote results to "{}"'.format(index, len(to_align), aligned_file))
-        total_fragments += file_total_fragments
-        dropped_fragments += file_dropped_fragments
-        reasons += file_reasons
-
-    logging.info('Aligned {} fragments'.format(total_fragments))
-    if total_fragments > 0 and dropped_fragments > 0:
-        logging.info('Dropped {} fragments {:0.2f}%:'.format(dropped_fragments,
-                                                             dropped_fragments * 100.0 / total_fragments))
-        for key, number in reasons.most_common():
-            logging.info(' - {}: {}'.format(key, number))
-
-
-if __name__ == '__main__':
+    model_format = (args.stt_model_rate, 1, 2)
     main()

align/audio.py (305 changes)

@@ -1,25 +1,110 @@
 import os
+import io
 import sox
 import wave
+import opuslib
 import tempfile
+import collections
+import numpy as np
+
+from webrtcvad import Vad
+from utils import LimitingPool
 
 DEFAULT_RATE = 16000
 DEFAULT_CHANNELS = 1
 DEFAULT_WIDTH = 2
 DEFAULT_FORMAT = (DEFAULT_RATE, DEFAULT_CHANNELS, DEFAULT_WIDTH)
 
+AUDIO_TYPE_NP = 'np'
+AUDIO_TYPE_PCM = 'pcm'
+AUDIO_FILE_PREFIX = 'audio/'
+AUDIO_TYPE_WAV = AUDIO_FILE_PREFIX + 'wav'
+AUDIO_TYPE_OPUS = AUDIO_FILE_PREFIX + 'opus'
+LOADABLE_FILE_FORMATS = [AUDIO_TYPE_WAV, AUDIO_TYPE_OPUS]
 
-def get_audio_format(wav_file):
-    return wav_file.getframerate(), wav_file.getnchannels(), wav_file.getsampwidth()
+OPUS_PCM_LEN_SIZE = 4
+OPUS_RATE_SIZE = 4
+OPUS_CHANNELS_SIZE = 1
+OPUS_WIDTH_SIZE = 1
+OPUS_CHUNK_LEN_SIZE = 2
+
+NP_TYPE_LOOKUP = [None, np.int8, np.int16, None, np.int32]
+UNSUPPORTED_TYPE = 'Unsupported audio type: {}'
 
 
-def set_audio_format(wav_file, audio_format=DEFAULT_FORMAT):
+class Sample:
+    def __init__(self, audio_type, raw_data, audio_format=None):
+        self.audio_type = audio_type
+        self.audio_format = audio_format
+        if audio_type in LOADABLE_FILE_FORMATS:
+            self.audio = io.BytesIO(raw_data)
+            self.duration = read_duration(audio_type, self.audio)
+        else:
+            self.audio = raw_data
+            if self.audio_format is None:
+                raise ValueError('For audio type "{}" parameter "audio_format" is mandatory')
+            if audio_type == AUDIO_TYPE_PCM:
+                self.duration = get_pcm_duration(len(self.audio), self.audio_format)
+            elif audio_type == AUDIO_TYPE_NP:
+                self.duration = get_np_duration(len(self.audio), self.audio_format)
+            else:
+                raise ValueError(UNSUPPORTED_TYPE.format(self.audio_type))
+
+    def convert(self, new_audio_type):
+        if self.audio_type == new_audio_type:
+            return
+        if new_audio_type == AUDIO_TYPE_PCM and self.audio_type in LOADABLE_FILE_FORMATS:
+            self.audio_format, audio = read_audio(self.audio_type, self.audio)
+            self.audio.close()
+            self.audio = audio
+        elif new_audio_type == AUDIO_TYPE_NP:
+            self.convert(AUDIO_TYPE_PCM)
+            self.audio = pcm_to_np(self.audio_format, self.audio)
+        elif new_audio_type in LOADABLE_FILE_FORMATS:
+            self.convert(AUDIO_TYPE_PCM)
+            audio_bytes = io.BytesIO()
+            write_audio(new_audio_type, audio_bytes, self.audio_format, self.audio)
+            audio_bytes.seek(0)
+            self.audio = audio_bytes
+        else:
+            raise RuntimeError('Audio conversion from "{}" to "{}" not supported'
+                               .format(self.audio_type, new_audio_type))
+        self.audio_type = new_audio_type
+
+
+def convert_samples(samples, audio_type=AUDIO_TYPE_PCM, processes=None):
+    def convert_sample(sample):
+        sample.convert(audio_type)
+        return sample
+    with LimitingPool(processes=processes) as pool:
+        for current_sample in pool.map(convert_sample, samples):
+            yield current_sample
+
+
+def write_audio_format_to_wav_file(wav_file, audio_format=DEFAULT_FORMAT):
     rate, channels, width = audio_format
     wav_file.setframerate(rate)
     wav_file.setnchannels(channels)
     wav_file.setsampwidth(width)
 
 
+def read_audio_format_from_wav_file(wav_file):
+    return wav_file.getframerate(), wav_file.getnchannels(), wav_file.getsampwidth()
+
+
+def get_num_samples(pcm_len, audio_format=DEFAULT_FORMAT):
+    _, channels, width = audio_format
+    return pcm_len // (channels * width)
+
+
+def get_pcm_duration(pcm_len, audio_format=DEFAULT_FORMAT):
+    return get_num_samples(pcm_len, audio_format) / audio_format[0]
+
+
+def get_np_duration(np_len, audio_format=DEFAULT_FORMAT):
+    return np_len / audio_format[0]
+
+
 def convert_audio(src_audio_path, dst_audio_path, file_type=None, audio_format=DEFAULT_FORMAT):
     sample_rate, channels, width = audio_format
     transformer = sox.Transformer()
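The new Sample class keeps audio in whatever representation it was given and converts on demand. A hypothetical round trip (the WAV path is a placeholder; pcm_to_np and read_duration are defined in a later hunk of this file):

    with open('sample.wav', 'rb') as wav_file:
        sample = Sample(AUDIO_TYPE_WAV, wav_file.read())
    print(sample.duration)           # seconds, computed from the container header
    sample.convert(AUDIO_TYPE_NP)    # wav -> pcm -> normalized float32 numpy array
    print(sample.audio.shape)        # (num_samples, 1)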
@@ -30,7 +115,7 @@ def convert_audio(src_audio_path, dst_audio_path, file_type=None, audio_format=DEFAULT_FORMAT):
 def ensure_wav_with_format(src_audio_path, audio_format=DEFAULT_FORMAT):
     if src_audio_path.endswith('.wav'):
         with wave.open(src_audio_path, 'r') as src_audio_file:
-            if get_audio_format(src_audio_file) == audio_format:
+            if read_audio_format_from_wav_file(src_audio_file) == audio_format:
                 return src_audio_path, False
     fd, tmp_file_path = tempfile.mkstemp(suffix='.wav')
     os.close(fd)
@@ -43,3 +128,215 @@ def extract_audio(audio_file, start, end):
     rate = audio_file.getframerate()
     audio_file.setpos(int(start * rate))
     return audio_file.readframes(int((end - start) * rate))
+
+
+class AudioFile:
+    def __init__(self, audio_path, as_path=False, audio_format=DEFAULT_FORMAT):
+        self.audio_path = audio_path
+        self.audio_format = audio_format
+        self.as_path = as_path
+        self.open_file = None
+        self.tmp_file_path = None
+
+    def __enter__(self):
+        if self.audio_path.endswith('.wav'):
+            self.open_file = wave.open(self.audio_path, 'r')
+            if read_audio_format_from_wav_file(self.open_file) == self.audio_format:
+                if self.as_path:
+                    self.open_file.close()
+                    return self.audio_path
+                return self.open_file
+            self.open_file.close()
+        _, self.tmp_file_path = tempfile.mkstemp(suffix='.wav')
+        convert_audio(self.audio_path, self.tmp_file_path, file_type='wav', audio_format=self.audio_format)
+        if self.as_path:
+            return self.tmp_file_path
+        self.open_file = wave.open(self.tmp_file_path, 'r')
+        return self.open_file
+
+    def __exit__(self, *args):
+        if not self.as_path:
+            self.open_file.close()
+        if self.tmp_file_path is not None:
+            os.remove(self.tmp_file_path)
+
+
+def read_frames(wav_file, frame_duration_ms=30, yield_remainder=False):
+    audio_format = read_audio_format_from_wav_file(wav_file)
+    frame_size = int(audio_format[0] * (frame_duration_ms / 1000.0))
+    while True:
+        try:
+            data = wav_file.readframes(frame_size)
+            if not yield_remainder and get_pcm_duration(len(data), audio_format) * 1000 < frame_duration_ms:
+                break
+            yield data
+        except EOFError:
+            break
+
+
+def read_frames_from_file(audio_path, audio_format=DEFAULT_FORMAT, frame_duration_ms=30, yield_remainder=False):
+    with AudioFile(audio_path, audio_format=audio_format) as wav_file:
+        for frame in read_frames(wav_file, frame_duration_ms=frame_duration_ms, yield_remainder=yield_remainder):
+            yield frame
+
+
+def vad_split(audio_frames,
+              audio_format=DEFAULT_FORMAT,
+              num_padding_frames=10,
+              threshold=0.5,
+              aggressiveness=3):
+    sample_rate, channels, width = audio_format
+    if channels != 1:
+        raise ValueError('VAD-splitting requires mono samples')
+    if width != 2:
+        raise ValueError('VAD-splitting requires 16 bit samples')
+    if sample_rate not in [8000, 16000, 32000, 48000]:
+        raise ValueError('VAD-splitting only supported for sample rates 8000, 16000, 32000, or 48000')
+    if aggressiveness not in [0, 1, 2, 3]:
+        raise ValueError('VAD-splitting aggressiveness mode has to be one of 0, 1, 2, or 3')
+    ring_buffer = collections.deque(maxlen=num_padding_frames)
+    triggered = False
+    vad = Vad(int(aggressiveness))
+    voiced_frames = []
+    frame_duration_ms = 0
+    frame_index = 0
+    for frame_index, frame in enumerate(audio_frames):
+        frame_duration_ms = get_pcm_duration(len(frame), audio_format) * 1000
+        if int(frame_duration_ms) not in [10, 20, 30]:
+            raise ValueError('VAD-splitting only supported for frame durations 10, 20, or 30 ms')
+        is_speech = vad.is_speech(frame, sample_rate)
+        if not triggered:
+            ring_buffer.append((frame, is_speech))
+            num_voiced = len([f for f, speech in ring_buffer if speech])
+            if num_voiced > threshold * ring_buffer.maxlen:
+                triggered = True
+                for f, s in ring_buffer:
+                    voiced_frames.append(f)
+                ring_buffer.clear()
+        else:
+            voiced_frames.append(frame)
+            ring_buffer.append((frame, is_speech))
+            num_unvoiced = len([f for f, speech in ring_buffer if not speech])
+            if num_unvoiced > threshold * ring_buffer.maxlen:
+                triggered = False
+                yield b''.join(voiced_frames), \
+                      frame_duration_ms * max(0, frame_index - len(voiced_frames)), \
+                      frame_duration_ms * frame_index
+                ring_buffer.clear()
+                voiced_frames = []
+    if len(voiced_frames) > 0:
+        yield b''.join(voiced_frames), \
+              frame_duration_ms * (frame_index - len(voiced_frames)), \
+              frame_duration_ms * (frame_index + 1)
+
+
+def pack_number(n, num_bytes):
+    return n.to_bytes(num_bytes, 'big', signed=False)
+
+
+def unpack_number(data):
+    return int.from_bytes(data, 'big', signed=False)
+
+
+def get_opus_frame_size(rate):
+    return 60 * rate // 1000
+
+
+def write_opus(opus_file, audio_format, audio_data):
+    rate, channels, width = audio_format
+    frame_size = get_opus_frame_size(rate)
+    encoder = opuslib.Encoder(rate, channels, opuslib.APPLICATION_AUDIO)
+    chunk_size = frame_size * channels * width
+    opus_file.write(pack_number(len(audio_data), OPUS_PCM_LEN_SIZE))
+    opus_file.write(pack_number(rate, OPUS_RATE_SIZE))
+    opus_file.write(pack_number(channels, OPUS_CHANNELS_SIZE))
+    opus_file.write(pack_number(width, OPUS_WIDTH_SIZE))
+    for i in range(0, len(audio_data), chunk_size):
+        chunk = audio_data[i:i + chunk_size]
+        encoded = encoder.encode(chunk, frame_size)
+        opus_file.write(pack_number(len(encoded), OPUS_CHUNK_LEN_SIZE))
+        opus_file.write(encoded)
+
+
+def read_opus_header(opus_file):
+    opus_file.seek(0)
+    pcm_len = unpack_number(opus_file.read(OPUS_PCM_LEN_SIZE))
+    rate = unpack_number(opus_file.read(OPUS_RATE_SIZE))
+    channels = unpack_number(opus_file.read(OPUS_CHANNELS_SIZE))
+    width = unpack_number(opus_file.read(OPUS_WIDTH_SIZE))
+    return pcm_len, (rate, channels, width)
+
+
+def read_opus(opus_file):
+    pcm_len, audio_format = read_opus_header(opus_file)
+    rate, channels, _ = audio_format
+    frame_size = get_opus_frame_size(rate)
+    decoder = opuslib.Decoder(rate, channels)
+    audio_data = bytearray()
+    while len(audio_data) < pcm_len:
+        chunk_len = unpack_number(opus_file.read(OPUS_CHUNK_LEN_SIZE))
+        chunk = opus_file.read(chunk_len)
+        decoded = decoder.decode(chunk, frame_size)
+        audio_data.extend(decoded)
+    audio_data = audio_data[:pcm_len]
+    return audio_format, audio_data
+
+
+def write_wav(wav_file, audio_format, pcm_data):
+    with wave.open(wav_file, 'wb') as wav_file_writer:
+        write_audio_format_to_wav_file(wav_file_writer, audio_format)
+        wav_file_writer.writeframes(pcm_data)
+
+
+def read_wav(wav_file):
+    wav_file.seek(0)
+    with wave.open(wav_file, 'rb') as wav_file_reader:
+        audio_format = read_audio_format_from_wav_file(wav_file_reader)
+        pcm_data = wav_file_reader.readframes(wav_file_reader.getnframes())
+        return audio_format, pcm_data
+
+
+def read_audio(audio_type, audio_file):
+    if audio_type == AUDIO_TYPE_WAV:
+        return read_wav(audio_file)
+    if audio_type == AUDIO_TYPE_OPUS:
+        return read_opus(audio_file)
+    raise ValueError(UNSUPPORTED_TYPE.format(audio_type))
+
+
+def write_audio(audio_type, audio_file, audio_format, pcm_data):
+    if audio_type == AUDIO_TYPE_WAV:
+        return write_wav(audio_file, audio_format, pcm_data)
+    if audio_type == AUDIO_TYPE_OPUS:
+        return write_opus(audio_file, audio_format, pcm_data)
+    raise ValueError(UNSUPPORTED_TYPE.format(audio_type))
+
+
+def read_wav_duration(wav_file):
+    wav_file.seek(0)
+    with wave.open(wav_file, 'rb') as wav_file_reader:
+        return wav_file_reader.getnframes() / wav_file_reader.getframerate()
+
+
+def read_opus_duration(opus_file):
+    pcm_len, audio_format = read_opus_header(opus_file)
+    return get_pcm_duration(pcm_len, audio_format)
+
+
+def read_duration(audio_type, audio_file):
+    if audio_type == AUDIO_TYPE_WAV:
+        return read_wav_duration(audio_file)
+    if audio_type == AUDIO_TYPE_OPUS:
+        return read_opus_duration(audio_file)
+    raise ValueError(UNSUPPORTED_TYPE.format(audio_type))
+
+
+def pcm_to_np(audio_format, pcm_data):
+    _, channels, width = audio_format
+    if width < 1 or width > 4 or width == 3:
+        raise ValueError('Unsupported sample width: {}'.format(width))
+    dtype = NP_TYPE_LOOKUP[width]
+    samples = np.frombuffer(pcm_data, dtype=dtype)
+    samples = samples[::channels]  # limited to mono for now
+    samples = samples.astype(np.float32) / np.iinfo(dtype).max
+    return np.expand_dims(samples, axis=1)
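A minimal sketch of how these helpers compose, mirroring the call sites in align.py (the input path is a placeholder; the format is the 16 kHz/mono/16-bit default):

    frames = read_frames_from_file('speech.wav', DEFAULT_FORMAT, frame_duration_ms=30)
    segments = vad_split(frames, DEFAULT_FORMAT,
                         num_padding_frames=10, threshold=0.5, aggressiveness=3)
    for pcm, start_ms, end_ms in segments:
        print('voiced segment {:.0f}-{:.0f} ms, {} bytes'.format(start_ms, end_ms, len(pcm)))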

align/export.py (file name not captured; inferred from content)

@@ -16,7 +16,7 @@ from tqdm import tqdm
 from datetime import timedelta
 from collections import Counter
 from multiprocessing import Pool
-from audio import DEFAULT_FORMAT, ensure_wav_with_format, extract_audio, set_audio_format
+from audio import DEFAULT_FORMAT, ensure_wav_with_format, extract_audio, write_audio_format_to_wav_file
 
 audio_format = DEFAULT_FORMAT
 unknown = '<unknown>'
@@ -430,7 +430,7 @@ def main(args):
             sample_path = '{}/sample-{:010d}.wav'.format(fragment['list-name'], len(group_list))
             with TargetFile(sample_path, "wb") as base_wav_file:
                 with wave.open(base_wav_file, 'wb') as wav_file:
-                    set_audio_format(wav_file)
+                    write_audio_format_to_wav_file(wav_file)
                     wav_file.writeframes(audio_segment)
                 file_size = base_wav_file.tell()
             group_list.append((sample_path, file_size, fragment))

align/utils.py (file name not captured; inferred from the LimitingPool import in audio.py)

@@ -1,3 +1,6 @@
+from multiprocessing.dummy import Pool as ThreadPool
 
 def circulate(items, center=None):
     count = len(list(items))
     if count > 0:
@@ -57,3 +60,30 @@ def greedy_minimum_search(a, b, compute, result_a=None, result_b=None):
         return greedy_minimum_search(a, c, compute, result_a=result_a)
     else:
         return greedy_minimum_search(c, b, compute, result_b=result_b)
+
+
+class LimitingPool:
+    def __init__(self, processes=None, limit_factor=2, sleeping_for=0.1):
+        self.processes = os.cpu_count() if processes is None else processes
+        self.pool = ThreadPool(processes=processes)
+        self.sleeping_for = sleeping_for
+        self.max_ahead = self.processes * limit_factor
+        self.processed = 0
+
+    def __enter__(self):
+        return self
+
+    def limit(self, it):
+        for obj in it:
+            while self.processed >= self.max_ahead:
+                time.sleep(self.sleeping_for)
+            self.processed += 1
+            yield obj
+
+    def map(self, fun, it):
+        for obj in self.pool.imap(fun, self.limit(it)):
+            self.processed -= 1
+            yield obj
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.pool.close()
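LimitingPool is a thread pool whose map throttles the producer: at most processes * limit_factor items are in flight at once, which keeps memory bounded when mapping over long generators (as convert_samples in audio.py does). Note the class relies on os.cpu_count and time.sleep; the matching imports are presumably among the lines the first hunk of this file adds, though only the ThreadPool import was captured. A small hypothetical use:

    def work(n):
        return n * n

    with LimitingPool(processes=4) as pool:      # at most 4 * 2 = 8 items ahead
        for result in pool.map(work, range(1000)):
            print(result)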

align/wavSplit.py (deleted)

@@ -1,128 +0,0 @@
-import collections
-from webrtcvad import Vad
-from pydub import AudioSegment
-
-
-class Frame(object):
-    """Represents a "frame" of audio data."""
-    def __init__(self, bytes, timestamp, duration):
-        self.bytes = bytes
-        self.timestamp = timestamp
-        self.duration = duration
-
-
-def frame_generator(frame_duration_ms, audio, sample_rate):
-    """Generates audio frames from PCM audio data.
-
-    Takes the desired frame duration in milliseconds, the PCM data, and
-    the sample rate.
-
-    Yields Frames of the requested duration.
-    """
-    n = int(sample_rate * (frame_duration_ms / 1000.0) * 2)
-    offset = 0
-    timestamp = 0.0
-    duration = (float(n) / sample_rate) / 2.0
-    while offset + n < len(audio):
-        yield Frame(audio[offset:offset + n], timestamp, duration)
-        timestamp += duration
-        offset += n
-
-
-def vad_collector(sample_rate, frame_duration_ms,
-                  padding_duration_ms, threshold, vad, frames):
-    """Filters out non-voiced audio frames.
-
-    Given a webrtcvad.Vad and a source of audio frames, yields only
-    the voiced audio.
-
-    Uses a padded, sliding window algorithm over the audio frames.
-    When more than 90% of the frames in the window are voiced (as
-    reported by the VAD), the collector triggers and begins yielding
-    audio frames. Then the collector waits until 90% of the frames in
-    the window are unvoiced to detrigger.
-
-    The window is padded at the front and back to provide a small
-    amount of silence or the beginnings/endings of speech around the
-    voiced frames.
-
-    Arguments:
-
-    sample_rate - The audio sample rate, in Hz.
-    frame_duration_ms - The frame duration in milliseconds.
-    padding_duration_ms - The amount to pad the window, in milliseconds.
-    vad - An instance of webrtcvad.Vad.
-    frames - a source of audio frames (sequence or generator).
-
-    Returns: A generator that yields PCM audio data.
-    """
-    num_padding_frames = int(padding_duration_ms / frame_duration_ms)
-    # We use a deque for our sliding window/ring buffer.
-    ring_buffer = collections.deque(maxlen=num_padding_frames)
-    # We have two states: TRIGGERED and NOTTRIGGERED. We start in the
-    # NOTTRIGGERED state.
-    triggered = False
-
-    voiced_frames = []
-    for frame_index, frame in enumerate(frames):
-        is_speech = vad.is_speech(frame.bytes, sample_rate)
-
-        if not triggered:
-            ring_buffer.append((frame, is_speech))
-            num_voiced = len([f for f, speech in ring_buffer if speech])
-            # If we're NOTTRIGGERED and more than 90% of the frames in
-            # the ring buffer are voiced frames, then enter the
-            # TRIGGERED state.
-            if num_voiced > threshold * ring_buffer.maxlen:
-                triggered = True
-                # We want to yield all the audio we see from now until
-                # we are NOTTRIGGERED, but we have to start with the
-                # audio that's already in the ring buffer.
-                for f, s in ring_buffer:
-                    voiced_frames.append(f)
-                ring_buffer.clear()
-        else:
-            # We're in the TRIGGERED state, so collect the audio data
-            # and add it to the ring buffer.
-            voiced_frames.append(frame)
-            ring_buffer.append((frame, is_speech))
-            num_unvoiced = len([f for f, speech in ring_buffer if not speech])
-            # If more than 90% of the frames in the ring buffer are
-            # unvoiced, then enter NOTTRIGGERED and yield whatever
-            # audio we've collected.
-            if num_unvoiced > threshold * ring_buffer.maxlen:
-                triggered = False
-                yield b''.join([f.bytes for f in voiced_frames]), \
-                      frame_duration_ms * max(0, frame_index - len(voiced_frames)), \
-                      frame_duration_ms * frame_index
-                ring_buffer.clear()
-                voiced_frames = []
-
-    if triggered:
-        pass
-    # If we have any leftover voiced audio when we run out of input,
-    # yield it.
-    if voiced_frames:
-        yield b''.join([f.bytes for f in voiced_frames]), \
-              frame_duration_ms * (frame_index - len(voiced_frames)), \
-              frame_duration_ms * (frame_index + 1)
-
-
-def vad_segment_generator(audio_file, aggressiveness):
-    """
-    Generate VAD segments. Filters out non-voiced audio frames.
-    :param audio_file: Input audio file to run VAD on.
-    :param aggressiveness: How aggressive filtering out non-speech is (between 0 and 3)
-    :return: Returns tuple of
-        segments: a bytearray of multiple smaller audio frames
-                  (The longer audio split into multiple smaller one's)
-        sample_rate: Sample rate of the input audio file
-        audio_length: Duration of the input audio file
-    """
-    audio = (AudioSegment.from_file(audio_file)
-             .set_channels(1)
-             .set_frame_rate(16000))
-    vad = Vad(int(aggressiveness))
-    frames = frame_generator(30, audio.raw_data, audio.frame_rate)
-    segments = vad_collector(audio.frame_rate, 30, 300, 0.5, vad, frames)
-    return segments, audio.frame_rate, audio.duration_seconds * 1000

align/wavTranscriber.py (deleted)

@@ -1,51 +0,0 @@
-import glob
-from deepspeech import Model
-
-
-def load_model(models, alphabet, lm, trie):
-    """
-    Load the pre-trained model into the memory
-    :param models: Output Graph Protocol Buffer file
-    :param alphabet: Alphabet.txt file
-    :param lm: Language model file
-    :param trie: Trie file
-    :return: tuple (DeepSpeech object, Model Load Time, LM Load Time)
-    """
-    N_FEATURES = 26
-    N_CONTEXT = 9
-    BEAM_WIDTH = 500
-    #LM_ALPHA = 0.75
-    #LM_BETA = 1.85
-
-    LM_ALPHA = 1
-    LM_BETA = 1.85
-
-    ds = Model(models, BEAM_WIDTH)
-    ds.enableDecoderWithLM(lm, trie, LM_ALPHA, LM_BETA)
-    return ds
-
-
-def stt(ds, audio, fs):
-    """
-    Run Inference on input audio file
-    :param ds: DeepSpeech object
-    :param audio: Input audio for running inference on
-    :param fs: Sample rate of the input audio file
-    :return: tuple (Inference result text, Inference time)
-    """
-    audio_length = len(audio) * (1 / 16000)
-    # Run DeepSpeech
-    output = ds.stt(audio)
-    return output
-
-
-def resolve_models(dir_name):
-    """
-    Resolve directory path for the models and fetch each of them.
-    :param dir_name: Path to the directory containing pre-trained models
-    :return: tuple containing each of the model files (pb, alphabet, lm and trie)
-    """
-    pb = glob.glob(dir_name + "/*.pb")[0]
-    lm = glob.glob(dir_name + "/lm.binary")[0]
-    trie = glob.glob(dir_name + "/trie")[0]
-    return pb, lm, trie