Mirror of https://github.com/mozilla/DSAlign.git
This commit is contained in:
Parent: 407a6d26d1
Commit: f5d5e7a0dc
@@ -5,9 +5,6 @@ import argparse
 import numpy as np
 import wavTranscriber
 
-# Debug helpers
-logging.basicConfig(stream=sys.stderr, level=logging.DEBUG)
-
 
 def main(args):
     parser = argparse.ArgumentParser(description='Transcribe long audio files using webRTC VAD or use the streaming interface')
@@ -17,15 +14,17 @@ def main(args):
                         help='Determines how aggressive filtering out non-speech is. (Interger between 0-3)')
     parser.add_argument('--model', required=False,
                         help='Path to directory that contains all model files (output_graph, lm, trie and alphabet)')
+    parser.add_argument('--loglevel', type=int, required=False,
+                        help='Log level (between 0 and 50) - default: 20')
     args = parser.parse_args()
 
+    # Debug helpers
+    logging.basicConfig(stream=sys.stderr, level=args.loglevel if args.loglevel else 20)
+
     # Loading model
     model_dir = os.path.expanduser(args.model if args.model else 'models/en')
     output_graph, alphabet, lm, trie = wavTranscriber.resolve_models(model_dir)
-    model = wavTranscriber.load_model(output_graph, alphabet, lm, trie)
 
     title_names = ['Filename', 'Duration(s)', 'Inference Time(s)', 'Model Load Time(s)', 'LM Load Time(s)']
     print("\n%-30s %-20s %-20s %-20s %s" % (title_names[0], title_names[1], title_names[2], title_names[3], title_names[4]))
+    model, _, _ = wavTranscriber.load_model(output_graph, alphabet, lm, trie)
 
     inference_time = 0.0
 
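A note on the new flag: when --loglevel is omitted, logging falls back to 20 (logging.INFO). A minimal sketch of the fallback expression used above, including its one caveat, namely that `args.loglevel if args.loglevel else 20` treats an explicit `--loglevel 0` (logging.NOTSET) as falsy and maps it to 20 as well:

    import logging
    import sys

    loglevel = None  # stand-in for args.loglevel
    # Same pattern as in the hunk above; `20 if loglevel is None else loglevel`
    # would preserve an explicit 0.
    logging.basicConfig(stream=sys.stderr, level=loglevel if loglevel else 20)
    logging.info('visible at the default level (20, INFO)')
    logging.debug('only visible with --loglevel 10 or lower')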
@@ -37,26 +36,17 @@ def main(args):
         logging.debug("Saving Transcript @: %s" % wave_file.rstrip(".wav") + ".txt")
 
         for i, segment in enumerate(segments):
-            # Run deepspeech on the chunk that just completed VAD
+            # Run DeepSpeech on the chunk that just completed VAD
             logging.debug("Processing chunk %002d" % (i,))
             audio = np.frombuffer(segment, dtype=np.int16)
-            output = wavTranscriber.stt(model[0], audio, sample_rate)
-            inference_time += output[1]
-            logging.debug("Transcript: %s" % output[0])
+            transcript, segment_inference_time = wavTranscriber.stt(model, audio, sample_rate)
+            inference_time += segment_inference_time
+            logging.info("Transcript: %s" % transcript)
 
-            f.write(output[0] + " ")
+            f.write(transcript + " ")
 
         # Summary of the files processed
         f.close()
 
         # Extract filename from the full file path
         filename, ext = os.path.split(os.path.basename(wave_file))
         logging.debug("************************************************************************************************************")
         logging.debug("%-30s %-20s %-20s %-20s %s" % (title_names[0], title_names[1], title_names[2], title_names[3], title_names[4]))
         logging.debug("%-30s %-20.3f %-20.3f %-20.3f %-0.3f" % (filename + ext, audio_length, inference_time, model_retval[1], model_retval[2]))
         logging.debug("************************************************************************************************************")
         print("%-30s %-20.3f %-20.3f %-20.3f %-0.3f" % (filename + ext, audio_length, inference_time, model_retval[1], model_retval[2]))
 
 
 if __name__ == '__main__':
     main(sys.argv[1:])
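For reference, the bytes-to-samples step in the loop above: each VAD segment is raw little-endian 16-bit PCM, and np.frombuffer reinterprets it as the int16 vector that the STT call consumes. A self-contained sketch:

    import numpy as np

    # Three 16-bit samples: 0, 32767 (0x7fff), -32767 (0x8001).
    segment = b'\x00\x00\xff\x7f\x01\x80'
    audio = np.frombuffer(segment, dtype=np.int16)
    print(audio)        # [     0  32767 -32767]
    print(audio.dtype)  # int16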
@@ -0,0 +1,107 @@
+from __future__ import absolute_import, division, print_function
+
+import codecs
+import re
+
+import numpy as np
+
+from six.moves import range
+
+
+class Alphabet(object):
+    def __init__(self, config_file):
+        self._config_file = config_file
+        self._label_to_str = []
+        self._str_to_label = {}
+        self._size = 0
+        with codecs.open(config_file, 'r', 'utf-8') as fin:
+            for line in fin:
+                if line[0:2] == '\\#':
+                    line = '#\n'
+                elif line[0] == '#':
+                    continue
+                self._label_to_str += line[:-1]  # remove the line ending
+                self._str_to_label[line[:-1]] = self._size
+                self._size += 1
+
+    def string_from_label(self, label):
+        return self._label_to_str[label]
+
+    def label_from_string(self, string):
+        try:
+            return self._str_to_label[string]
+        except KeyError as e:
+            raise KeyError(
+                '''ERROR: Your transcripts contain characters which do not occur in data/alphabet.txt! Use util/check_characters.py to see what characters are in your {train,dev,test}.csv transcripts, and then add all these to data/alphabet.txt.'''
+            ).with_traceback(e.__traceback__)
+
+    def decode(self, labels):
+        res = ''
+        for label in labels:
+            res += self.string_from_label(label)
+        return res
+
+    def size(self):
+        return self._size
+
+    def config_file(self):
+        return self._config_file
+
+
+def text_to_char_array(original, alphabet):
+    """
+    Given a Python string ``original``, remove unsupported characters, map characters
+    to integers and return a numpy array representing the processed string.
+    """
+    return np.asarray([alphabet.label_from_string(c) for c in original])
+
+
+# Validate and normalize transcriptions. Returns a cleaned version of the label
+# or None if it's invalid.
+def validate_label(label):
+    # For now we can only handle [a-z ']
+    if re.search(r"[0-9]|[(<\[\]&*{]", label) is not None:
+        return None
+
+    label = label.replace("-", "")
+    label = label.replace("_", "")
+    label = label.replace(".", "")
+    label = label.replace(",", "")
+    label = label.replace("?", "")
+    label = label.replace("\"", "")
+    label = label.strip()
+    label = label.lower()
+
+    return label if label else None
+
+
+# The following code is from: http://hetland.org/coding/python/levenshtein.py
+
+# This is a straightforward implementation of a well-known algorithm, and thus
+# probably shouldn't be covered by copyright to begin with. But in case it is,
+# the author (Magnus Lie Hetland) has, to the extent possible under law,
+# dedicated all copyright and related and neighboring rights to this software
+# to the public domain worldwide, by distributing it under the CC0 license,
+# version 1.0. This software is distributed without any warranty. For more
+# information, see <http://creativecommons.org/publicdomain/zero/1.0>
+
+def levenshtein(a, b):
+    """
+    Calculates the Levenshtein distance between a and b.
+    """
+    n, m = len(a), len(b)
+    if n > m:
+        # Make sure n <= m, to use O(min(n,m)) space
+        a, b = b, a
+        n, m = m, n
+
+    current = list(range(n+1))
+    for i in range(1, m+1):
+        previous, current = current, [i]+[0]*n
+        for j in range(1, n+1):
+            add, delete = previous[j]+1, current[j-1]+1
+            change = previous[j-1]
+            if a[j-1] != b[i-1]:
+                change = change + 1
+            current[j] = min(add, delete, change)
+
+    return current[n]
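The new module bundles three text utilities: the Alphabet label mapping, transcript validation, and Levenshtein distance. A quick demonstration of all three (the module's filename is not shown in this diff, so `text` is assumed below, and the alphabet file is a stand-in with one character per line):

    import tempfile

    from text import Alphabet, levenshtein, text_to_char_array, validate_label

    # Build a 28-entry alphabet: space, a-z, apostrophe.
    with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False) as f:
        f.write(' \n' + '\n'.join('abcdefghijklmnopqrstuvwxyz') + "\n'\n")
        alphabet_path = f.name

    alphabet = Alphabet(alphabet_path)
    print(alphabet.size())                      # 28
    print(alphabet.label_from_string('a'))      # 1 (space is label 0)
    print(alphabet.decode([8, 5, 12, 12, 15]))  # hello
    print(text_to_char_array('hi', alphabet))   # [8 9]

    print(validate_label('Hello, world?'))      # hello world
    print(validate_label('take 5'))             # None (digits are rejected)

    print(levenshtein('kitten', 'sitting'))     # 3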
@@ -60,7 +60,7 @@ def frame_generator(frame_duration_ms, audio, sample_rate):
 
 
 def vad_collector(sample_rate, frame_duration_ms,
-                  padding_duration_ms, vad, frames):
+                  padding_duration_ms, threshold, vad, frames):
     """Filters out non-voiced audio frames.
 
     Given a webrtcvad.Vad and a source of audio frames, yields only

@@ -103,7 +103,7 @@ def vad_collector(sample_rate, frame_duration_ms,
             # If we're NOTTRIGGERED and more than 90% of the frames in
             # the ring buffer are voiced frames, then enter the
             # TRIGGERED state.
-            if num_voiced > 0.9 * ring_buffer.maxlen:
+            if num_voiced > threshold * ring_buffer.maxlen:
                 triggered = True
                 # We want to yield all the audio we see from now until
                 # we are NOTTRIGGERED, but we have to start with the

@@ -120,7 +120,7 @@ def vad_collector(sample_rate, frame_duration_ms,
             # If more than 90% of the frames in the ring buffer are
             # unvoiced, then enter NOTTRIGGERED and yield whatever
             # audio we've collected.
-            if num_unvoiced > 0.9 * ring_buffer.maxlen:
+            if num_unvoiced > threshold * ring_buffer.maxlen:
                 triggered = False
                 yield b''.join([f.bytes for f in voiced_frames])
                 ring_buffer.clear()
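The effect of this change: the 90% trigger ratio that was hardcoded in both state transitions is now a caller-supplied `threshold` (the updated wavTranscriber below passes 0.5, while 0.9 reproduces the old behaviour). A sketch of calling the new signature directly, assuming a 16 kHz mono file at a placeholder path:

    import webrtcvad

    import wavSplit

    audio, sample_rate, audio_length = wavSplit.read_wave('speech.wav')  # placeholder path
    vad = webrtcvad.Vad(1)
    frames = list(wavSplit.frame_generator(30, audio, sample_rate))

    # Enter/leave the TRIGGERED state once half of the 300 ms ring buffer
    # flips, rather than the previous fixed 90%.
    segments = wavSplit.vad_collector(sample_rate, 30, 300, 0.5, vad, frames)
    for i, segment in enumerate(segments):
        print("segment %d: %d bytes" % (i, len(segment)))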
@@ -5,17 +5,16 @@ import wavSplit
 from deepspeech import Model
 from timeit import default_timer as timer
 
-'''
-Load the pre-trained model into the memory
-@param models: Output Grapgh Protocol Buffer file
-@param alphabet: Alphabet.txt file
-@param lm: Language model file
-@param trie: Trie file
-
-@Retval
-Returns a list [DeepSpeech Object, Model Load Time, LM Load Time]
-'''
 def load_model(models, alphabet, lm, trie):
+    """
+    Load the pre-trained model into the memory
+    :param models: Output Graph Protocol Buffer file
+    :param alphabet: Alphabet.txt file
+    :param lm: Language model file
+    :param trie: Trie file
+    :return: tuple (DeepSpeech object, Model Load Time, LM Load Time)
+    """
     N_FEATURES = 26
     N_CONTEXT = 9
     BEAM_WIDTH = 500
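Together with the hunk below, load_model now returns a plain tuple rather than a list, so call sites can unpack it by name or discard the timings. A sketch, with placeholder model paths:

    import wavTranscriber

    ds, model_load_time, lm_load_time = wavTranscriber.load_model(
        'models/en/output_graph.pb', 'models/en/alphabet.txt',
        'models/en/lm.binary', 'models/en/trie')
    print('model load: %.3fs, lm load: %.3fs' % (model_load_time, lm_load_time))
    # Call sites that only need the model itself can write:
    # model, _, _ = wavTranscriber.load_model(...), as in the script above.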
@@ -32,72 +31,65 @@ def load_model(models, alphabet, lm, trie):
     lm_load_end = timer() - lm_load_start
     logging.debug('Loaded language model in %0.3fs.' % (lm_load_end))
 
-    return [ds, model_load_end, lm_load_end]
+    return ds, model_load_end, lm_load_end
 
-'''
-Run Inference on input audio file
-@param ds: Deepspeech object
-@param audio: Input audio for running inference on
-@param fs: Sample rate of the input audio file
-
-@Retval:
-Returns a list [Inference, Inference Time, Audio Length]
-
-'''
 def stt(ds, audio, fs):
-    inference_time = 0.0
+    """
+    Run Inference on input audio file
+    :param ds: DeepSpeech object
+    :param audio: Input audio for running inference on
+    :param fs: Sample rate of the input audio file
+    :return: tuple (Inference result text, Inference time)
+    """
     audio_length = len(audio) * (1 / 16000)
 
-    # Run Deepspeech
+    # Run DeepSpeech
     logging.debug('Running inference...')
     inference_start = timer()
     output = ds.stt(audio, fs)
-    inference_end = timer() - inference_start
-    inference_time += inference_end
-    logging.debug('Inference took %0.3fs for %0.3fs audio file.' % (inference_end, audio_length))
+    inference_time = timer() - inference_start
+    logging.debug('Inference took %0.3fs for %0.3fs audio file.' % (inference_time, audio_length))
+    return output, inference_time
 
-    return [output, inference_time]
 
-'''
-Resolve directory path for the models and fetch each of them.
-@param dirName: Path to the directory containing pre-trained models
-
-@Retval:
-Retunns a tuple containing each of the model files (pb, alphabet, lm and trie)
-'''
-def resolve_models(dirName):
-    pb = glob.glob(dirName + "/*.pb")[0]
+def resolve_models(dir_name):
+    """
+    Resolve directory path for the models and fetch each of them.
+    :param dir_name: Path to the directory containing pre-trained models
+    :return: tuple containing each of the model files (pb, alphabet, lm and trie)
+    """
+    pb = glob.glob(dir_name + "/*.pb")[0]
     logging.debug("Found Model: %s" % pb)
 
-    alphabet = glob.glob(dirName + "/alphabet.txt")[0]
+    alphabet = glob.glob(dir_name + "/alphabet.txt")[0]
     logging.debug("Found Alphabet: %s" % alphabet)
 
-    lm = glob.glob(dirName + "/lm.binary")[0]
-    trie = glob.glob(dirName + "/trie")[0]
+    lm = glob.glob(dir_name + "/lm.binary")[0]
+    trie = glob.glob(dir_name + "/trie")[0]
     logging.debug("Found Language Model: %s" % lm)
     logging.debug("Found Trie: %s" % trie)
 
     return pb, alphabet, lm, trie
 
-'''
-Generate VAD segments. Filters out non-voiced audio frames.
-@param waveFile: Input wav file to run VAD on.0
-
-@Retval:
-Returns tuple of
-segments: a bytearray of multiple smaller audio frames
-(The longer audio split into mutiple smaller one's)
-sample_rate: Sample rate of the input audio file
-audio_length: Duraton of the input audio file
-
-'''
-def vad_segment_generator(wavFile, aggressiveness):
-    logging.debug("Caught the wav file @: %s" % (wavFile))
-    audio, sample_rate, audio_length = wavSplit.read_wave(wavFile)
+def vad_segment_generator(wav_file, aggressiveness):
+    """
+    Generate VAD segments. Filters out non-voiced audio frames.
+    :param wav_file: Input wav file to run VAD on.0
+    :param aggressiveness: How aggressive filtering out non-speech is (between 0 and 3)
+    :return: Returns tuple of
+        segments: a bytearray of multiple smaller audio frames
+        (The longer audio split into multiple smaller one's)
+        sample_rate: Sample rate of the input audio file
+        audio_length: Duration of the input audio file
+    """
+    logging.debug("Caught the wav file @: %s" % wav_file)
+    audio, sample_rate, audio_length = wavSplit.read_wave(wav_file)
     assert sample_rate == 16000, "Only 16000Hz input WAV files are supported for now!"
     vad = webrtcvad.Vad(int(aggressiveness))
     frames = wavSplit.frame_generator(30, audio, sample_rate)
     frames = list(frames)
-    segments = wavSplit.vad_collector(sample_rate, 30, 300, vad, frames)
+    segments = wavSplit.vad_collector(sample_rate, 30, 300, 0.5, vad, frames)
 
     return segments, sample_rate, audio_length
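Taken together, the refactored module reads as a small pipeline: resolve the model files, load them, split the input with VAD, then transcribe each voiced segment. An end-to-end sketch of the new calling conventions ('audio.wav' is a placeholder for a 16 kHz mono WAV file):

    import numpy as np

    import wavTranscriber

    output_graph, alphabet, lm, trie = wavTranscriber.resolve_models('models/en')
    model, _, _ = wavTranscriber.load_model(output_graph, alphabet, lm, trie)

    segments, sample_rate, audio_length = wavTranscriber.vad_segment_generator('audio.wav', 1)

    inference_time = 0.0
    for segment in segments:
        audio = np.frombuffer(segment, dtype=np.int16)
        transcript, segment_inference_time = wavTranscriber.stt(model, audio, sample_rate)
        inference_time += segment_inference_time
        print(transcript)

    print('transcribed %.3fs of audio in %.3fs' % (audio_length, inference_time))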