Light refactoring to integrate DeepSpeech audio helpers. Additional VAD parameters.

This commit is contained in:
Tilman Kamp 2020-02-06 17:06:44 +01:00
Parent d17489c29e
Commit fba9c8971d
6 changed files with 566 additions and 405 deletions
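A minimal sketch (not part of the commit) of the flow this change wires together: VAD-split an audio file with the new helpers from audio.py, then transcribe each segment with the deepspeech package directly, as the refactored init_stt/stt now do. Model and audio paths are placeholders.

import numpy as np
import deepspeech
from audio import read_frames_from_file, vad_split

BEAM_WIDTH = 500
LM_ALPHA = 1
LM_BETA = 1.85
model_format = (16000, 1, 2)  # (rate, channels, sample width), as used by align.py

# Placeholder paths - not taken from this commit
model = deepspeech.Model('models/en/output_graph.pb', BEAM_WIDTH)
model.enableDecoderWithLM('models/en/lm.binary', 'models/en/trie', LM_ALPHA, LM_BETA)

frames = read_frames_from_file('recording.wav', model_format, frame_duration_ms=30)
segments = vad_split(frames,
                     model_format,
                     num_padding_frames=10,
                     threshold=0.5,
                     aggressiveness=3)
for segment_buffer, time_start, time_end in segments:
    audio = np.frombuffer(segment_buffer, dtype=np.int16)
    print('{} - {}: {}'.format(time_start, time_end, model.stt(audio)))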

View file

@@ -3,40 +3,40 @@ import sys
import json
import logging
import argparse
import deepspeech
import subprocess
import os.path as path
import numpy as np
import textdistance
import wavSplit
import wavTranscriber
import multiprocessing
from collections import Counter
from search import FuzzySearch
from glob import glob
from tqdm import tqdm
from text import Alphabet, TextCleaner, levenshtein, similarity
from utils import enweight
from audio import DEFAULT_RATE, read_frames_from_file, vad_split
algos = ['WNG', 'jaro_winkler', 'editex', 'levenshtein', 'mra', 'hamming']
sim_desc = 'From 0.0 (not equal at all) to 100.0 (totally equal)'
named_numbers = {
BEAM_WIDTH = 500
LM_ALPHA = 1
LM_BETA = 1.85
ALGORITHMS = ['WNG', 'jaro_winkler', 'editex', 'levenshtein', 'mra', 'hamming']
SIM_DESC = 'From 0.0 (not equal at all) to 100.0 (totally equal)'
NAMED_NUMBERS = {
'tlen': ('transcript length', int, None),
'mlen': ('match length', int, None),
'SWS': ('Smith-Waterman score', float, 'From 0.0 (not equal at all) to 100.0+ (pretty equal)'),
'WNG': ('weighted N-gram similarity', float, sim_desc),
'jaro_winkler': ('Jaro-Winkler similarity', float, sim_desc),
'editex': ('Editex similarity', float, sim_desc),
'levenshtein': ('Levenshtein similarity', float, sim_desc),
'mra': ('MRA similarity', float, sim_desc),
'hamming': ('Hamming similarity', float, sim_desc),
'WNG': ('weighted N-gram similarity', float, SIM_DESC),
'jaro_winkler': ('Jaro-Winkler similarity', float, SIM_DESC),
'editex': ('Editex similarity', float, SIM_DESC),
'levenshtein': ('Levenshtein similarity', float, SIM_DESC),
'mra': ('MRA similarity', float, SIM_DESC),
'hamming': ('Hamming similarity', float, SIM_DESC),
'CER': ('character error rate', float, 'From 0.0 (no wrong characters) to 100.0+ (total miss)'),
'WER': ('word error rate', float, 'From 0.0 (no different words) to 100.0+ (total miss)')
}
args = None
model = None
sample_rate = 0
alphabet = None
def fail(message, code=1):
logging.fatal(message)
@@ -61,27 +61,23 @@ def read_script(script_path):
return tc
def init_stt(output_graph_path, alphabet_path, lm_path, trie_path, rate):
global model, sample_rate
sample_rate = rate
model = None
def init_stt(output_graph_path, lm_path, trie_path):
global model
model = deepspeech.Model(output_graph_path, BEAM_WIDTH)
model.enableDecoderWithLM(lm_path, trie_path, LM_ALPHA, LM_BETA)
logging.debug('Process {}: Loaded models'.format(os.getpid()))
model = wavTranscriber.load_model(output_graph_path, alphabet_path, lm_path, trie_path)
def stt(sample):
time_start, time_end, audio = sample
logging.debug('Process {}: Transcribing...'.format(os.getpid()))
transcript = wavTranscriber.stt(model, audio, sample_rate)
transcript = model.stt(audio)
logging.debug('Process {}: {}'.format(os.getpid(), transcript))
return time_start, time_end, ' '.join(transcript.split())
def init_align(w_args, w_alphabet):
global args, alphabet
args = w_args
alphabet = w_alphabet
def align(triple):
tlog, script, aligned = triple
@@ -166,7 +162,7 @@ def align(triple):
max_ngram_size=args.align_wng_max_size,
size_factor=args.align_wng_size_factor,
position_factor=args.align_wng_position_factor)
elif algo in algos:
elif algo in ALGORITHMS:
algo_impl = similarity_algos[algo] = getattr(textdistance, algo).normalized_similarity
else:
logging.fatal('Unknown similarity metric "{}"'.format(algo))
@@ -257,7 +253,7 @@ def align(triple):
show.insert(0, '{}: {:.2f}'.format(number_key, val))
if should_output:
fragment[kl] = val
reason_base = '{} ({})'.format(named_numbers[number_key][0], number_key)
reason_base = '{} ({})'.format(NAMED_NUMBERS[number_key][0], number_key)
reason = None
if min_val and val < min_val:
reason = reason_base + ' too low'
@@ -321,7 +317,7 @@ def align(triple):
continue
should_skip = False
for algo in algos:
for algo in ALGORITHMS:
should_skip = should_skip or apply_number(algo, index, result_fragment, sample_numbers,
lambda: 100 * phrase_similarity(algo,
fragment_matched,
@@ -361,7 +357,194 @@ def align(triple):
def main():
global args, alphabet
# Debug helpers
logging.basicConfig(stream=sys.stdout, level=args.loglevel if args.loglevel else 20)
def progress(iter, **kwargs):
return iter if args.no_progress else tqdm(iter, **kwargs)
def resolve(base_path, spec_path):
if spec_path is None:
return None
if not path.isabs(spec_path):
spec_path = path.join(base_path, spec_path)
return spec_path
def exists(file_path):
if file_path is None:
return False
return os.path.isfile(file_path)
to_prepare = []
def enqueue_or_fail(audio, tlog, script, aligned, prefix=''):
if exists(aligned) and not args.force:
fail(prefix + 'Alignment file "{}" already existing - use --force to overwrite'.format(aligned))
if tlog is None:
if args.ignore_missing:
return
fail(prefix + 'Missing transcription log path')
if not exists(audio) and not exists(tlog):
if args.ignore_missing:
return
fail(prefix + 'Both audio file "{}" and transcription log "{}" are missing'.format(audio, tlog))
if not exists(script):
if args.ignore_missing:
return
fail(prefix + 'Missing script "{}"'.format(script))
to_prepare.append((audio, tlog, script, aligned))
if (args.audio or args.tlog) and args.script and args.aligned and not args.catalog:
enqueue_or_fail(args.audio, args.tlog, args.script, args.aligned)
elif args.catalog:
if not exists(args.catalog):
fail('Unable to load catalog file "{}"'.format(args.catalog))
catalog = path.abspath(args.catalog)
catalog_dir = path.dirname(catalog)
with open(catalog, 'r') as catalog_file:
catalog_entries = json.load(catalog_file)
for entry in progress(catalog_entries, desc='Reading catalog'):
enqueue_or_fail(resolve(catalog_dir, entry['audio']),
resolve(catalog_dir, entry['tlog']),
resolve(catalog_dir, entry['script']),
resolve(catalog_dir, entry['aligned']),
prefix='Problem loading catalog "{}" - '.format(catalog))
else:
fail('You have to either specify a combination of "--audio/--tlog,--script,--aligned" or "--catalog"')
logging.debug('Start')
to_align = []
output_graph_path = None
for audio_path, tlog_path, script_path, aligned_path in to_prepare:
if not exists(tlog_path):
if output_graph_path is None:
logging.debug('Looking for model files in "{}"...'.format(model_dir))
output_graph_path = glob(model_dir + "/output_graph.pb")[0]
lang_lm_path = glob(model_dir + "/lm.binary")[0]
lang_trie_path = glob(model_dir + "/trie")[0]
kenlm_path = 'dependencies/kenlm/build/bin'
if not path.exists(kenlm_path):
kenlm_path = None
deepspeech_path = 'dependencies/deepspeech'
if not path.exists(deepspeech_path):
deepspeech_path = None
if kenlm_path and deepspeech_path and not args.stt_no_own_lm:
tc = read_script(script_path)
if not tc.clean_text.strip():
logging.error('Cleaned transcript is empty for {}'.format(path.basename(script_path)))
continue
clean_text_path = script_path + '.clean'
with open(clean_text_path, 'w') as clean_text_file:
clean_text_file.write(tc.clean_text)
arpa_path = script_path + '.arpa'
if not path.exists(arpa_path):
subprocess.check_call([
kenlm_path + '/lmplz',
'--discount_fallback',
'--text',
clean_text_path,
'--arpa',
arpa_path,
'--o',
'5'
])
lm_path = script_path + '.lm'
if not path.exists(lm_path):
subprocess.check_call([
kenlm_path + '/build_binary',
'-s',
arpa_path,
lm_path
])
trie_path = script_path + '.trie'
if not path.exists(trie_path):
subprocess.check_call([
deepspeech_path + '/generate_trie',
alphabet_path,
lm_path,
trie_path
])
else:
lm_path = lang_lm_path
trie_path = lang_trie_path
logging.debug('Loading acoustic model from "{}", alphabet from "{}", trie from "{}" and language model from "{}"...'
.format(output_graph_path, alphabet_path, trie_path, lm_path))
# Run VAD on the input file
logging.debug('Transcribing VAD segments...')
frames = read_frames_from_file(audio_path, model_format, args.audio_vad_frame_length)
segments = vad_split(frames,
model_format,
num_padding_frames=args.audio_vad_padding,
threshold=args.audio_vad_threshold,
aggressiveness=args.audio_vad_aggressiveness)
def pre_filter():
for i, segment in enumerate(segments):
segment_buffer, time_start, time_end = segment
time_length = time_end - time_start
if args.stt_min_duration and time_length < args.stt_min_duration:
logging.info('Fragment {}: Audio too short for STT'.format(i))
continue
if args.stt_max_duration and time_length > args.stt_max_duration:
logging.info('Fragment {}: Audio too long for STT'.format(i))
continue
yield (time_start, time_end, np.frombuffer(segment_buffer, dtype=np.int16))
samples = list(progress(pre_filter(), desc='VAD splitting'))
pool = multiprocessing.Pool(initializer=init_stt,
initargs=(output_graph_path, lm_path, trie_path),
processes=args.stt_workers)
transcripts = progress(pool.imap(stt, samples), desc='Transcribing', total=len(samples))
fragments = []
for time_start, time_end, segment_transcript in transcripts:
if segment_transcript is None:
continue
fragments.append({
'start': time_start,
'end': time_end,
'transcript': segment_transcript
})
logging.debug('Excluded {} empty transcripts'.format(len(transcripts) - len(fragments)))
logging.debug('Writing transcription log to file "{}"...'.format(tlog_path))
with open(tlog_path, 'w') as tlog_file:
tlog_file.write(json.dumps(fragments, indent=4 if args.output_pretty else None))
if not path.isfile(tlog_path):
fail('Problem loading transcript from "{}"'.format(tlog_path))
to_align.append((tlog_path, script_path, aligned_path))
total_fragments = 0
dropped_fragments = 0
reasons = Counter()
index = 0
pool = multiprocessing.Pool(processes=args.align_workers)
for aligned_file, file_total_fragments, file_dropped_fragments, file_reasons in \
progress(pool.imap_unordered(align, to_align), desc='Aligning', total=len(to_align)):
if args.no_progress:
index += 1
logging.info('Aligned file {} of {} - wrote results to "{}"'.format(index, len(to_align), aligned_file))
total_fragments += file_total_fragments
dropped_fragments += file_dropped_fragments
reasons += file_reasons
logging.info('Aligned {} fragments'.format(total_fragments))
if total_fragments > 0 and dropped_fragments > 0:
logging.info('Dropped {} fragments {:0.2f}%:'.format(dropped_fragments,
dropped_fragments * 100.0 / total_fragments))
for key, number in reasons.most_common():
logging.info(' - {}: {}'.format(key, number))
def parse_args():
parser = argparse.ArgumentParser(description='Force align speech data with a transcript.')
parser.add_argument('--audio', type=str,
@@ -395,10 +578,20 @@ def main():
help='Path to an alphabet file (overriding the one from --stt-model-dir)')
audio_group = parser.add_argument_group(title='Audio pre-processing options')
audio_group.add_argument('--audio-vad-aggressiveness', type=int, choices=range(4), required=False,
help='Determines how aggressive filtering out non-speech is (default: 3)')
audio_group.add_argument('--audio-vad-aggressiveness', type=int, choices=range(4), default=3,
help='Aggressiveness of voice activity detection in a frame (default: 3)')
audio_group.add_argument('--audio-vad-padding', type=int, default=10,
help='Number of padding audio frames in VAD ring-buffer')
audio_group.add_argument('--audio-vad-threshold', type=float, default=0.5,
help='VAD ring-buffer threshold for voiced frames '
'(e.g. 0.5 -> 50%% of the ring-buffer frames have to be voiced '
'for triggering a split)')
audio_group.add_argument('--audio-vad-frame-length', choices=[10, 20, 30], default=30,
help='VAD audio frame length in ms (10, 20 or 30)')
stt_group = parser.add_argument_group(title='STT options')
stt_group.add_argument('--stt-model-rate', type=int, default=DEFAULT_RATE,
help='Supported sample rate of the acoustic model')
stt_group.add_argument('--stt-model-dir', required=False,
help='Path to a directory with output_graph, lm, trie and (optional) alphabet file ' +
'(default: "models/en")')
@@ -467,8 +660,8 @@ def main():
output_group.add_argument('--output-pretty', action="store_true",
help='Writes indented JSON output')
for short in named_numbers.keys():
long, atype, desc = named_numbers[short]
for short in NAMED_NUMBERS.keys():
long, atype, desc = NAMED_NUMBERS[short]
desc = (' - value range: ' + desc) if desc else ''
output_group.add_argument('--output-' + short.lower(), action="store_true",
help='Writes {} ({}) to output'.format(long, short))
@@ -476,69 +669,14 @@ def main():
output_group.add_argument('--output-' + extreme.lower() + '-' + short.lower(), type=atype, required=False,
help='{}imum {} ({}) the STT transcript of the audio '
'has to have when compared with the original text{}'
.format(extreme, long, short, desc))
.format(extreme, long, short, desc))
args = parser.parse_args()
return parser.parse_args()
# Debug helpers
logging.basicConfig(stream=sys.stdout, level=args.loglevel if args.loglevel else 20)
def progress(iter, **kwargs):
return iter if args.no_progress else tqdm(iter, **kwargs)
def resolve(base_path, spec_path):
if spec_path is None:
return None
if not path.isabs(spec_path):
spec_path = path.join(base_path, spec_path)
return spec_path
def exists(file_path):
if file_path is None:
return False
return os.path.isfile(file_path)
to_prepare = []
def enqueue_or_fail(audio, tlog, script, aligned, prefix=''):
if exists(aligned) and not args.force:
fail(prefix + 'Alignment file "{}" already existing - use --force to overwrite'.format(aligned))
if tlog is None:
if args.ignore_missing:
return
fail(prefix + 'Missing transcription log path')
if not exists(audio) and not exists(tlog):
if args.ignore_missing:
return
fail(prefix + 'Both audio file "{}" and transcription log "{}" are missing'.format(audio, tlog))
if not exists(script):
if args.ignore_missing:
return
fail(prefix + 'Missing script "{}"'.format(script))
to_prepare.append((audio, tlog, script, aligned))
if (args.audio or args.tlog) and args.script and args.aligned and not args.catalog:
enqueue_or_fail(args.audio, args.tlog, args.script, args.aligned)
elif args.catalog:
if not exists(args.catalog):
fail('Unable to load catalog file "{}"'.format(args.catalog))
catalog = path.abspath(args.catalog)
catalog_dir = path.dirname(catalog)
with open(catalog, 'r') as catalog_file:
catalog_entries = json.load(catalog_file)
for entry in progress(catalog_entries, desc='Reading catalog'):
enqueue_or_fail(resolve(catalog_dir, entry['audio']),
resolve(catalog_dir, entry['tlog']),
resolve(catalog_dir, entry['script']),
resolve(catalog_dir, entry['aligned']),
prefix='Problem loading catalog "{}" - '.format(catalog))
else:
fail('You have to either specify a combination of "--audio/--tlog,--script,--aligned" or "--catalog"')
logging.debug('Start')
if __name__ == '__main__':
args = parse_args()
model_dir = os.path.expanduser(args.stt_model_dir if args.stt_model_dir else 'models/en')
if args.alphabet is not None:
alphabet_path = args.alphabet
else:
@@ -547,130 +685,5 @@ def main():
fail('Found no alphabet file')
logging.debug('Loading alphabet from "{}"...'.format(alphabet_path))
alphabet = Alphabet(alphabet_path)
to_align = []
output_graph_path = None
for audio, tlog, script, aligned in to_prepare:
if not exists(tlog):
if output_graph_path is None:
logging.debug('Looking for model files in "{}"...'.format(model_dir))
output_graph_path, lang_lm_path, lang_trie_path = wavTranscriber.resolve_models(model_dir)
kenlm_path = 'dependencies/kenlm/build/bin'
if not path.exists(kenlm_path):
kenlm_path = None
deepspeech_path = 'dependencies/deepspeech'
if not path.exists(deepspeech_path):
deepspeech_path = None
if kenlm_path and deepspeech_path and not args.stt_no_own_lm:
tc = read_script(script)
if not tc.clean_text.strip():
logging.error('Cleaned transcript is empty for {}'.format(path.basename(script)))
continue
clean_text_path = script + '.clean'
with open(clean_text_path, 'w') as clean_text_file:
clean_text_file.write(tc.clean_text)
arpa_path = script + '.arpa'
if not path.exists(arpa_path):
subprocess.check_call([
kenlm_path + '/lmplz',
'--discount_fallback',
'--text',
clean_text_path,
'--arpa',
arpa_path,
'--o',
'5'
])
lm_path = script + '.lm'
if not path.exists(lm_path):
subprocess.check_call([
kenlm_path + '/build_binary',
'-s',
arpa_path,
lm_path
])
trie_path = script + '.trie'
if not path.exists(trie_path):
subprocess.check_call([
deepspeech_path + '/generate_trie',
alphabet_path,
lm_path,
trie_path
])
else:
lm_path = lang_lm_path
trie_path = lang_trie_path
logging.debug('Loading acoustic model from "{}", alphabet from "{}", trie from "{}" and language model from "{}"...'
.format(output_graph_path, alphabet_path, trie_path, lm_path))
# Run VAD on the input file
logging.debug('Transcribing VAD segments...')
aggressiveness = int(args.audio_vad_aggressiveness) if args.audio_vad_aggressiveness else 3
segments, rate, audio_length = wavSplit.vad_segment_generator(audio, aggressiveness)
def pre_filter():
for i, segment in enumerate(segments):
segment_buffer, time_start, time_end = segment
time_length = time_end - time_start
if args.stt_min_duration and time_length < args.stt_min_duration:
logging.info('Fragment {}: Audio too short for STT'.format(i))
continue
if args.stt_max_duration and time_length > args.stt_max_duration:
logging.info('Fragment {}: Audio too long for STT'.format(i))
continue
yield (time_start, time_end, np.frombuffer(segment_buffer, dtype=np.int16))
samples = list(progress(pre_filter(), desc='VAD splitting'))
pool = multiprocessing.Pool(initializer=init_stt,
initargs=(output_graph_path, alphabet_path, lm_path, trie_path, rate),
processes=args.stt_workers)
transcripts = progress(pool.imap(stt, samples), desc='Transcribing', total=len(samples))
fragments = []
for time_start, time_end, segment_transcript in transcripts:
if segment_transcript is None:
continue
fragments.append({
'start': time_start,
'end': time_end,
'transcript': segment_transcript
})
logging.debug('Excluded {} empty transcripts'.format(len(transcripts) - len(fragments)))
logging.debug('Writing transcription log to file "{}"...'.format(tlog))
with open(tlog, 'w') as tlog_file:
tlog_file.write(json.dumps(fragments, indent=4 if args.output_pretty else None))
if not path.isfile(tlog):
fail('Problem loading transcript from "{}"'.format(tlog))
to_align.append((tlog, script, aligned))
total_fragments = 0
dropped_fragments = 0
reasons = Counter()
index = 0
pool = multiprocessing.Pool(initializer=init_align, initargs=(args, alphabet), processes=args.align_workers)
for aligned_file, file_total_fragments, file_dropped_fragments, file_reasons in \
progress(pool.imap_unordered(align, to_align), desc='Aligning', total=len(to_align)):
if args.no_progress:
index += 1
logging.info('Aligned file {} of {} - wrote results to "{}"'.format(index, len(to_align), aligned_file))
total_fragments += file_total_fragments
dropped_fragments += file_dropped_fragments
reasons += file_reasons
logging.info('Aligned {} fragments'.format(total_fragments))
if total_fragments > 0 and dropped_fragments > 0:
logging.info('Dropped {} fragments {:0.2f}%:'.format(dropped_fragments,
dropped_fragments * 100.0 / total_fragments))
for key, number in reasons.most_common():
logging.info(' - {}: {}'.format(key, number))
if __name__ == '__main__':
model_format = (args.stt_model_rate, 1, 2)
main()

View file

@@ -1,25 +1,110 @@
import os
import io
import sox
import wave
import opuslib
import tempfile
import collections
import numpy as np
from webrtcvad import Vad
from utils import LimitingPool
DEFAULT_RATE = 16000
DEFAULT_CHANNELS = 1
DEFAULT_WIDTH = 2
DEFAULT_FORMAT = (DEFAULT_RATE, DEFAULT_CHANNELS, DEFAULT_WIDTH)
AUDIO_TYPE_NP = 'np'
AUDIO_TYPE_PCM = 'pcm'
AUDIO_FILE_PREFIX = 'audio/'
AUDIO_TYPE_WAV = AUDIO_FILE_PREFIX + 'wav'
AUDIO_TYPE_OPUS = AUDIO_FILE_PREFIX + 'opus'
LOADABLE_FILE_FORMATS = [AUDIO_TYPE_WAV, AUDIO_TYPE_OPUS]
def get_audio_format(wav_file):
return wav_file.getframerate(), wav_file.getnchannels(), wav_file.getsampwidth()
OPUS_PCM_LEN_SIZE = 4
OPUS_RATE_SIZE = 4
OPUS_CHANNELS_SIZE = 1
OPUS_WIDTH_SIZE = 1
OPUS_CHUNK_LEN_SIZE = 2
NP_TYPE_LOOKUP = [None, np.int8, np.int16, None, np.int32]
UNSUPPORTED_TYPE = 'Unsupported audio type: {}'
def set_audio_format(wav_file, audio_format=DEFAULT_FORMAT):
class Sample:
def __init__(self, audio_type, raw_data, audio_format=None):
self.audio_type = audio_type
self.audio_format = audio_format
if audio_type in LOADABLE_FILE_FORMATS:
self.audio = io.BytesIO(raw_data)
self.duration = read_duration(audio_type, self.audio)
else:
self.audio = raw_data
if self.audio_format is None:
raise ValueError('For audio type "{}" parameter "audio_format" is mandatory'.format(self.audio_type))
if audio_type == AUDIO_TYPE_PCM:
self.duration = get_pcm_duration(len(self.audio), self.audio_format)
elif audio_type == AUDIO_TYPE_NP:
self.duration = get_np_duration(len(self.audio), self.audio_format)
else:
raise ValueError(UNSUPPORTED_TYPE.format(self.audio_type))
def convert(self, new_audio_type):
if self.audio_type == new_audio_type:
return
if new_audio_type == AUDIO_TYPE_PCM and self.audio_type in LOADABLE_FILE_FORMATS:
self.audio_format, audio = read_audio(self.audio_type, self.audio)
self.audio.close()
self.audio = audio
elif new_audio_type == AUDIO_TYPE_NP:
self.convert(AUDIO_TYPE_PCM)
self.audio = pcm_to_np(self.audio_format, self.audio)
elif new_audio_type in LOADABLE_FILE_FORMATS:
self.convert(AUDIO_TYPE_PCM)
audio_bytes = io.BytesIO()
write_audio(new_audio_type, audio_bytes, self.audio_format, self.audio)
audio_bytes.seek(0)
self.audio = audio_bytes
else:
raise RuntimeError('Audio conversion from "{}" to "{}" not supported'
.format(self.audio_type, new_audio_type))
self.audio_type = new_audio_type
def convert_samples(samples, audio_type=AUDIO_TYPE_PCM, processes=None):
def convert_sample(sample):
sample.convert(audio_type)
return sample
with LimitingPool(processes=processes) as pool:
for current_sample in pool.map(convert_sample, samples):
yield current_sample
def write_audio_format_to_wav_file(wav_file, audio_format=DEFAULT_FORMAT):
rate, channels, width = audio_format
wav_file.setframerate(rate)
wav_file.setnchannels(channels)
wav_file.setsampwidth(width)
def read_audio_format_from_wav_file(wav_file):
return wav_file.getframerate(), wav_file.getnchannels(), wav_file.getsampwidth()
def get_num_samples(pcm_len, audio_format=DEFAULT_FORMAT):
_, channels, width = audio_format
return pcm_len // (channels * width)
def get_pcm_duration(pcm_len, audio_format=DEFAULT_FORMAT):
return get_num_samples(pcm_len, audio_format) / audio_format[0]
def get_np_duration(np_len, audio_format=DEFAULT_FORMAT):
return np_len / audio_format[0]
def convert_audio(src_audio_path, dst_audio_path, file_type=None, audio_format=DEFAULT_FORMAT):
sample_rate, channels, width = audio_format
transformer = sox.Transformer()
@@ -30,7 +115,7 @@ def convert_audio(src_audio_path, dst_audio_path, file_type=None, audio_format=D
def ensure_wav_with_format(src_audio_path, audio_format=DEFAULT_FORMAT):
if src_audio_path.endswith('.wav'):
with wave.open(src_audio_path, 'r') as src_audio_file:
if get_audio_format(src_audio_file) == audio_format:
if read_audio_format_from_wav_file(src_audio_file) == audio_format:
return src_audio_path, False
fd, tmp_file_path = tempfile.mkstemp(suffix='.wav')
os.close(fd)
@@ -43,3 +128,215 @@ def extract_audio(audio_file, start, end):
rate = audio_file.getframerate()
audio_file.setpos(int(start * rate))
return audio_file.readframes(int((end - start) * rate))
class AudioFile:
def __init__(self, audio_path, as_path=False, audio_format=DEFAULT_FORMAT):
self.audio_path = audio_path
self.audio_format = audio_format
self.as_path = as_path
self.open_file = None
self.tmp_file_path = None
def __enter__(self):
if self.audio_path.endswith('.wav'):
self.open_file = wave.open(self.audio_path, 'r')
if read_audio_format_from_wav_file(self.open_file) == self.audio_format:
if self.as_path:
self.open_file.close()
return self.audio_path
return self.open_file
self.open_file.close()
_, self.tmp_file_path = tempfile.mkstemp(suffix='.wav')
convert_audio(self.audio_path, self.tmp_file_path, file_type='wav', audio_format=self.audio_format)
if self.as_path:
return self.tmp_file_path
self.open_file = wave.open(self.tmp_file_path, 'r')
return self.open_file
def __exit__(self, *args):
if not self.as_path:
self.open_file.close()
if self.tmp_file_path is not None:
os.remove(self.tmp_file_path)
def read_frames(wav_file, frame_duration_ms=30, yield_remainder=False):
audio_format = read_audio_format_from_wav_file(wav_file)
frame_size = int(audio_format[0] * (frame_duration_ms / 1000.0))
while True:
try:
data = wav_file.readframes(frame_size)
if not yield_remainder and get_pcm_duration(len(data), audio_format) * 1000 < frame_duration_ms:
break
yield data
except EOFError:
break
def read_frames_from_file(audio_path, audio_format=DEFAULT_FORMAT, frame_duration_ms=30, yield_remainder=False):
with AudioFile(audio_path, audio_format=audio_format) as wav_file:
for frame in read_frames(wav_file, frame_duration_ms=frame_duration_ms, yield_remainder=yield_remainder):
yield frame
def vad_split(audio_frames,
audio_format=DEFAULT_FORMAT,
num_padding_frames=10,
threshold=0.5,
aggressiveness=3):
sample_rate, channels, width = audio_format
if channels != 1:
raise ValueError('VAD-splitting requires mono samples')
if width != 2:
raise ValueError('VAD-splitting requires 16 bit samples')
if sample_rate not in [8000, 16000, 32000, 48000]:
raise ValueError('VAD-splitting only supported for sample rates 8000, 16000, 32000, or 48000')
if aggressiveness not in [0, 1, 2, 3]:
raise ValueError('VAD-splitting aggressiveness mode has to be one of 0, 1, 2, or 3')
ring_buffer = collections.deque(maxlen=num_padding_frames)
triggered = False
vad = Vad(int(aggressiveness))
voiced_frames = []
frame_duration_ms = 0
frame_index = 0
for frame_index, frame in enumerate(audio_frames):
frame_duration_ms = get_pcm_duration(len(frame), audio_format) * 1000
if int(frame_duration_ms) not in [10, 20, 30]:
raise ValueError('VAD-splitting only supported for frame durations 10, 20, or 30 ms')
is_speech = vad.is_speech(frame, sample_rate)
if not triggered:
ring_buffer.append((frame, is_speech))
num_voiced = len([f for f, speech in ring_buffer if speech])
if num_voiced > threshold * ring_buffer.maxlen:
triggered = True
for f, s in ring_buffer:
voiced_frames.append(f)
ring_buffer.clear()
else:
voiced_frames.append(frame)
ring_buffer.append((frame, is_speech))
num_unvoiced = len([f for f, speech in ring_buffer if not speech])
if num_unvoiced > threshold * ring_buffer.maxlen:
triggered = False
yield b''.join(voiced_frames), \
frame_duration_ms * max(0, frame_index - len(voiced_frames)), \
frame_duration_ms * frame_index
ring_buffer.clear()
voiced_frames = []
if len(voiced_frames) > 0:
yield b''.join(voiced_frames), \
frame_duration_ms * (frame_index - len(voiced_frames)), \
frame_duration_ms * (frame_index + 1)
def pack_number(n, num_bytes):
return n.to_bytes(num_bytes, 'big', signed=False)
def unpack_number(data):
return int.from_bytes(data, 'big', signed=False)
def get_opus_frame_size(rate):
return 60 * rate // 1000
def write_opus(opus_file, audio_format, audio_data):
rate, channels, width = audio_format
frame_size = get_opus_frame_size(rate)
encoder = opuslib.Encoder(rate, channels, opuslib.APPLICATION_AUDIO)
chunk_size = frame_size * channels * width
opus_file.write(pack_number(len(audio_data), OPUS_PCM_LEN_SIZE))
opus_file.write(pack_number(rate, OPUS_RATE_SIZE))
opus_file.write(pack_number(channels, OPUS_CHANNELS_SIZE))
opus_file.write(pack_number(width, OPUS_WIDTH_SIZE))
for i in range(0, len(audio_data), chunk_size):
chunk = audio_data[i:i + chunk_size]
encoded = encoder.encode(chunk, frame_size)
opus_file.write(pack_number(len(encoded), OPUS_CHUNK_LEN_SIZE))
opus_file.write(encoded)
def read_opus_header(opus_file):
opus_file.seek(0)
pcm_len = unpack_number(opus_file.read(OPUS_PCM_LEN_SIZE))
rate = unpack_number(opus_file.read(OPUS_RATE_SIZE))
channels = unpack_number(opus_file.read(OPUS_CHANNELS_SIZE))
width = unpack_number(opus_file.read(OPUS_WIDTH_SIZE))
return pcm_len, (rate, channels, width)
def read_opus(opus_file):
pcm_len, audio_format = read_opus_header(opus_file)
rate, channels, _ = audio_format
frame_size = get_opus_frame_size(rate)
decoder = opuslib.Decoder(rate, channels)
audio_data = bytearray()
while len(audio_data) < pcm_len:
chunk_len = unpack_number(opus_file.read(OPUS_CHUNK_LEN_SIZE))
chunk = opus_file.read(chunk_len)
decoded = decoder.decode(chunk, frame_size)
audio_data.extend(decoded)
audio_data = audio_data[:pcm_len]
return audio_format, audio_data
def write_wav(wav_file, audio_format, pcm_data):
with wave.open(wav_file, 'wb') as wav_file_writer:
write_audio_format_to_wav_file(wav_file_writer, audio_format)
wav_file_writer.writeframes(pcm_data)
def read_wav(wav_file):
wav_file.seek(0)
with wave.open(wav_file, 'rb') as wav_file_reader:
audio_format = read_audio_format_from_wav_file(wav_file_reader)
pcm_data = wav_file_reader.readframes(wav_file_reader.getnframes())
return audio_format, pcm_data
def read_audio(audio_type, audio_file):
if audio_type == AUDIO_TYPE_WAV:
return read_wav(audio_file)
if audio_type == AUDIO_TYPE_OPUS:
return read_opus(audio_file)
raise ValueError(UNSUPPORTED_TYPE.format(audio_type))
def write_audio(audio_type, audio_file, audio_format, pcm_data):
if audio_type == AUDIO_TYPE_WAV:
return write_wav(audio_file, audio_format, pcm_data)
if audio_type == AUDIO_TYPE_OPUS:
return write_opus(audio_file, audio_format, pcm_data)
raise ValueError(UNSUPPORTED_TYPE.format(audio_type))
def read_wav_duration(wav_file):
wav_file.seek(0)
with wave.open(wav_file, 'rb') as wav_file_reader:
return wav_file_reader.getnframes() / wav_file_reader.getframerate()
def read_opus_duration(opus_file):
pcm_len, audio_format = read_opus_header(opus_file)
return get_pcm_duration(pcm_len, audio_format)
def read_duration(audio_type, audio_file):
if audio_type == AUDIO_TYPE_WAV:
return read_wav_duration(audio_file)
if audio_type == AUDIO_TYPE_OPUS:
return read_opus_duration(audio_file)
raise ValueError(UNSUPPORTED_TYPE.format(audio_type))
def pcm_to_np(audio_format, pcm_data):
_, channels, width = audio_format
if width < 1 or width > 4 or width == 3:
raise ValueError('Unsupported sample width: {}'.format(width))
dtype = NP_TYPE_LOOKUP[width]
samples = np.frombuffer(pcm_data, dtype=dtype)
samples = samples[::channels] # limited to mono for now
samples = samples.astype(np.float32) / np.iinfo(dtype).max
return np.expand_dims(samples, axis=1)
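A short usage sketch (not part of the commit) of the new Sample container introduced above: wrap raw WAV bytes and convert them in place to a normalized NumPy array. The file name is a placeholder.

from audio import Sample, AUDIO_TYPE_WAV, AUDIO_TYPE_NP

with open('recording.wav', 'rb') as wav_file:  # placeholder path
    sample = Sample(AUDIO_TYPE_WAV, wav_file.read())
print(sample.duration)         # duration in seconds, read from the WAV header
sample.convert(AUDIO_TYPE_NP)  # WAV -> PCM -> float32 array of shape (num_samples, 1)
print(sample.audio.shape, sample.audio_format)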

View file

@@ -16,7 +16,7 @@ from tqdm import tqdm
from datetime import timedelta
from collections import Counter
from multiprocessing import Pool
from audio import DEFAULT_FORMAT, ensure_wav_with_format, extract_audio, set_audio_format
from audio import DEFAULT_FORMAT, ensure_wav_with_format, extract_audio, write_audio_format_to_wav_file
audio_format = DEFAULT_FORMAT
unknown = '<unknown>'
@@ -430,7 +430,7 @@ def main(args):
sample_path = '{}/sample-{:010d}.wav'.format(fragment['list-name'], len(group_list))
with TargetFile(sample_path, "wb") as base_wav_file:
with wave.open(base_wav_file, 'wb') as wav_file:
set_audio_format(wav_file)
write_audio_format_to_wav_file(wav_file)
wav_file.writeframes(audio_segment)
file_size = base_wav_file.tell()
group_list.append((sample_path, file_size, fragment))

View file

@@ -1,3 +1,6 @@
import os
import time
from multiprocessing.dummy import Pool as ThreadPool
def circulate(items, center=None):
count = len(list(items))
if count > 0:
@@ -57,3 +60,30 @@ def greedy_minimum_search(a, b, compute, result_a=None, result_b=None):
return greedy_minimum_search(a, c, compute, result_a=result_a)
else:
return greedy_minimum_search(c, b, compute, result_b=result_b)
class LimitingPool:
def __init__(self, processes=None, limit_factor=2, sleeping_for=0.1):
self.processes = os.cpu_count() if processes is None else processes
self.pool = ThreadPool(processes=processes)
self.sleeping_for = sleeping_for
self.max_ahead = self.processes * limit_factor
self.processed = 0
def __enter__(self):
return self
def limit(self, it):
for obj in it:
while self.processed >= self.max_ahead:
time.sleep(self.sleeping_for)
self.processed += 1
yield obj
def map(self, fun, it):
for obj in self.pool.imap(fun, self.limit(it)):
self.processed -= 1
yield obj
def __exit__(self, exc_type, exc_value, traceback):
self.pool.close()
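A small usage sketch (not part of the commit) for the new LimitingPool: a thread pool whose map keeps at most processes * limit_factor items in flight, so results can be consumed lazily without the input generator being read far ahead.

from utils import LimitingPool

def square(n):
    return n * n

with LimitingPool(processes=4) as pool:
    for result in pool.map(square, range(1000)):
        print(result)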

View file

@@ -1,128 +0,0 @@
import collections
from webrtcvad import Vad
from pydub import AudioSegment
class Frame(object):
"""Represents a "frame" of audio data."""
def __init__(self, bytes, timestamp, duration):
self.bytes = bytes
self.timestamp = timestamp
self.duration = duration
def frame_generator(frame_duration_ms, audio, sample_rate):
"""Generates audio frames from PCM audio data.
Takes the desired frame duration in milliseconds, the PCM data, and
the sample rate.
Yields Frames of the requested duration.
"""
n = int(sample_rate * (frame_duration_ms / 1000.0) * 2)
offset = 0
timestamp = 0.0
duration = (float(n) / sample_rate) / 2.0
while offset + n < len(audio):
yield Frame(audio[offset:offset + n], timestamp, duration)
timestamp += duration
offset += n
def vad_collector(sample_rate, frame_duration_ms,
padding_duration_ms, threshold, vad, frames):
"""Filters out non-voiced audio frames.
Given a webrtcvad.Vad and a source of audio frames, yields only
the voiced audio.
Uses a padded, sliding window algorithm over the audio frames.
When more than 90% of the frames in the window are voiced (as
reported by the VAD), the collector triggers and begins yielding
audio frames. Then the collector waits until 90% of the frames in
the window are unvoiced to detrigger.
The window is padded at the front and back to provide a small
amount of silence or the beginnings/endings of speech around the
voiced frames.
Arguments:
sample_rate - The audio sample rate, in Hz.
frame_duration_ms - The frame duration in milliseconds.
padding_duration_ms - The amount to pad the window, in milliseconds.
vad - An instance of webrtcvad.Vad.
frames - a source of audio frames (sequence or generator).
Returns: A generator that yields PCM audio data.
"""
num_padding_frames = int(padding_duration_ms / frame_duration_ms)
# We use a deque for our sliding window/ring buffer.
ring_buffer = collections.deque(maxlen=num_padding_frames)
# We have two states: TRIGGERED and NOTTRIGGERED. We start in the
# NOTTRIGGERED state.
triggered = False
voiced_frames = []
for frame_index, frame in enumerate(frames):
is_speech = vad.is_speech(frame.bytes, sample_rate)
if not triggered:
ring_buffer.append((frame, is_speech))
num_voiced = len([f for f, speech in ring_buffer if speech])
# If we're NOTTRIGGERED and more than 90% of the frames in
# the ring buffer are voiced frames, then enter the
# TRIGGERED state.
if num_voiced > threshold * ring_buffer.maxlen:
triggered = True
# We want to yield all the audio we see from now until
# we are NOTTRIGGERED, but we have to start with the
# audio that's already in the ring buffer.
for f, s in ring_buffer:
voiced_frames.append(f)
ring_buffer.clear()
else:
# We're in the TRIGGERED state, so collect the audio data
# and add it to the ring buffer.
voiced_frames.append(frame)
ring_buffer.append((frame, is_speech))
num_unvoiced = len([f for f, speech in ring_buffer if not speech])
# If more than 90% of the frames in the ring buffer are
# unvoiced, then enter NOTTRIGGERED and yield whatever
# audio we've collected.
if num_unvoiced > threshold * ring_buffer.maxlen:
triggered = False
yield b''.join([f.bytes for f in voiced_frames]), \
frame_duration_ms * max(0, frame_index - len(voiced_frames)), \
frame_duration_ms * frame_index
ring_buffer.clear()
voiced_frames = []
if triggered:
pass
# If we have any leftover voiced audio when we run out of input,
# yield it.
if voiced_frames:
yield b''.join([f.bytes for f in voiced_frames]), \
frame_duration_ms * (frame_index - len(voiced_frames)), \
frame_duration_ms * (frame_index + 1)
def vad_segment_generator(audio_file, aggressiveness):
"""
Generate VAD segments. Filters out non-voiced audio frames.
:param audio_file: Input audio file to run VAD on.
:param aggressiveness: How aggressive filtering out non-speech is (between 0 and 3)
:return: Returns tuple of
segments: a bytearray of multiple smaller audio frames
(The longer audio split into multiple smaller one's)
sample_rate: Sample rate of the input audio file
audio_length: Duration of the input audio file
"""
audio = (AudioSegment.from_file(audio_file)
.set_channels(1)
.set_frame_rate(16000))
vad = Vad(int(aggressiveness))
frames = frame_generator(30, audio.raw_data, audio.frame_rate)
segments = vad_collector(audio.frame_rate, 30, 300, 0.5, vad, frames)
return segments, audio.frame_rate, audio.duration_seconds * 1000
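The vad_collector docstring above describes the padded sliding-window trigger logic that lives on in the new vad_split in audio.py. A toy sketch (not part of the commit) of just that trigger/detrigger decision on a boolean voiced/unvoiced sequence:

import collections

def trigger_spans(is_speech_flags, num_padding_frames=10, threshold=0.5):
    # Yields (start, end) frame indices of voiced regions, including the padding
    # frames that tripped the trigger, mirroring how vad_split/vad_collector
    # prepend the ring buffer contents to each emitted segment.
    ring_buffer = collections.deque(maxlen=num_padding_frames)
    triggered = False
    start = 0
    for index, voiced in enumerate(is_speech_flags):
        ring_buffer.append(voiced)
        if not triggered:
            if sum(ring_buffer) > threshold * ring_buffer.maxlen:
                triggered = True
                start = index - len(ring_buffer) + 1
                ring_buffer.clear()
        elif len(ring_buffer) - sum(ring_buffer) > threshold * ring_buffer.maxlen:
            triggered = False
            yield start, index
            ring_buffer.clear()
    if triggered:
        yield start, len(is_speech_flags)

# Example: a 20-frame voiced run framed by silence yields one padded span.
# list(trigger_spans([False]*5 + [True]*20 + [False]*15, num_padding_frames=4)) -> [(4, 27)]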

View file

@@ -1,51 +0,0 @@
import glob
from deepspeech import Model
def load_model(models, alphabet, lm, trie):
"""
Load the pre-trained model into the memory
:param models: Output Graph Protocol Buffer file
:param alphabet: Alphabet.txt file
:param lm: Language model file
:param trie: Trie file
:return: tuple (DeepSpeech object, Model Load Time, LM Load Time)
"""
N_FEATURES = 26
N_CONTEXT = 9
BEAM_WIDTH = 500
#LM_ALPHA = 0.75
#LM_BETA = 1.85
LM_ALPHA = 1
LM_BETA = 1.85
ds = Model(models, BEAM_WIDTH)
ds.enableDecoderWithLM(lm, trie, LM_ALPHA, LM_BETA)
return ds
def stt(ds, audio, fs):
"""
Run Inference on input audio file
:param ds: DeepSpeech object
:param audio: Input audio for running inference on
:param fs: Sample rate of the input audio file
:return: tuple (Inference result text, Inference time)
"""
audio_length = len(audio) * (1 / 16000)
# Run DeepSpeech
output = ds.stt(audio)
return output
def resolve_models(dir_name):
"""
Resolve directory path for the models and fetch each of them.
:param dir_name: Path to the directory containing pre-trained models
:return: tuple containing each of the model files (pb, alphabet, lm and trie)
"""
pb = glob.glob(dir_name + "/*.pb")[0]
lm = glob.glob(dir_name + "/lm.binary")[0]
trie = glob.glob(dir_name + "/trie")[0]
return pb, lm, trie