Merge pull request #3493 from mozilla/add-ogg-opus-training-support

Add ogg opus training support
This commit is contained in:
Reuben Morais 2021-01-20 17:53:10 +00:00 committed by GitHub
Parents ad0f7d2ab7 b2feb04763
Commit 80b5fe10df
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
19 changed files with 279 additions and 128 deletions

View file

@@ -64,7 +64,7 @@ ENV TF_CUDA_PATHS "/usr,/usr/local/cuda-10.1,/usr/lib/x86_64-linux-gnu/"
 ENV TF_CUDA_VERSION 10.1
 ENV TF_CUDNN_VERSION 7.6
 ENV TF_CUDA_COMPUTE_CAPABILITIES 6.0
-ENV TF_NCCL_VERSION 2.7
+ENV TF_NCCL_VERSION 2.8
 # Common Environment Setup
 ENV TF_BUILD_CONTAINER_TYPE GPU

View file

@@ -4,6 +4,7 @@ Tool for comparing two wav samples
 """
 import sys
 import argparse
+import numpy as np

 from deepspeech_training.util.audio import AUDIO_TYPE_NP, mean_dbfs
 from deepspeech_training.util.sample_collections import load_sample
@@ -19,11 +20,15 @@ def compare_samples():
     sample2 = load_sample(CLI_ARGS.sample2).unpack()
     if sample1.audio_format != sample2.audio_format:
         fail('Samples differ on: audio-format ({} and {})'.format(sample1.audio_format, sample2.audio_format))
-    if sample1.duration != sample2.duration:
+    if abs(sample1.duration - sample2.duration) > 0.001:
         fail('Samples differ on: duration ({} and {})'.format(sample1.duration, sample2.duration))
     sample1.change_audio_type(AUDIO_TYPE_NP)
     sample2.change_audio_type(AUDIO_TYPE_NP)
-    audio_diff = sample1.audio - sample2.audio
+    samples = [sample1, sample2]
+    largest = np.argmax([sample1.audio.shape[0], sample2.audio.shape[0]])
+    smallest = (largest + 1) % 2
+    samples[largest].audio = samples[largest].audio[:len(samples[smallest].audio)]
+    audio_diff = samples[largest].audio - samples[smallest].audio
     diff_dbfs = mean_dbfs(audio_diff)
     differ_msg = 'Samples differ on: sample data ({:0.2f} dB difference) '.format(diff_dbfs)
     equal_msg = 'Samples are considered equal ({:0.2f} dB difference)'.format(diff_dbfs)
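Note: the comparison above now tolerates a 1 ms duration mismatch and trims the longer buffer before differencing. A minimal standalone sketch of that trimming step (aligned_diff is a hypothetical helper name; inputs are plain numpy buffers):

    import numpy as np

    def aligned_diff(a1, a2):
        # Trim the longer of two PCM buffers to the shorter one's length,
        # then return the element-wise difference.
        samples = [a1, a2]
        largest = int(np.argmax([a1.shape[0], a2.shape[0]]))
        smallest = (largest + 1) % 2
        samples[largest] = samples[largest][:len(samples[smallest])]
        return samples[largest] - samples[smallest]

    # Example: buffers differing by one trailing sample still compare cleanly
    a = np.zeros(16000, dtype=np.float32)
    b = np.zeros(16001, dtype=np.float32)
    assert aligned_diff(a, b).shape == (16000,)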

View file

@@ -9,14 +9,14 @@ import sys
 import random
 import argparse

-from deepspeech_training.util.audio import LOADABLE_AUDIO_EXTENSIONS, AUDIO_TYPE_PCM, AUDIO_TYPE_WAV
+from deepspeech_training.util.audio import get_loadable_audio_type_from_extension, AUDIO_TYPE_PCM, AUDIO_TYPE_WAV
 from deepspeech_training.util.sample_collections import SampleList, LabeledSample, samples_from_source
 from deepspeech_training.util.augmentations import parse_augmentations, apply_sample_augmentations, SampleAugmentation

 def get_samples_in_play_order():
     ext = os.path.splitext(CLI_ARGS.source)[1].lower()
-    if ext in LOADABLE_AUDIO_EXTENSIONS:
+    if get_loadable_audio_type_from_extension(ext):
         samples = SampleList([(CLI_ARGS.source, 0)], labeled=False)
     else:
         samples = samples_from_source(CLI_ARGS.source, buffering=0)

View file

@@ -50,22 +50,21 @@ def main():
         version = fin.read().strip()

     install_requires_base = [
-        'numpy',
-        'progressbar2',
-        'six',
-        'pyxdg',
-        'attrdict',
         'absl-py',
-        'semver',
-        'opuslib == 2.0.0',
-        'optuna',
-        'sox',
+        'attrdict',
         'bs4',
+        'numpy',
+        'optuna',
+        'opuslib == 2.0.0',
         'pandas',
+        'progressbar2',
+        'pyogg >= 0.6.14a1',
+        'pyxdg',
+        'resampy >= 0.2.2',
         'requests',
-        'numba == 0.47.0', # ships py3.5 wheel
-        'llvmlite == 0.31.0', # for numba==0.47.0
-        'librosa',
+        'semver',
+        'six',
+        'sox',
         'soundfile',
     ]

View file

@@ -8,5 +8,5 @@ build:
     tests_cmdline: "${system.homedir.linux}/DeepSpeech/ds/taskcluster/tc-cpp_tflite_basic-ds-tests.sh 8k"
   workerType: "${docker.dsTests}"
   metadata:
-    name: "DeepSpeech Linux AMD64 TFLite C++ tests (16kHz)"
-    description: "Testing DeepSpeech C++ for Linux/AMD64, TFLite, optimized version (16kHz)"
+    name: "DeepSpeech Linux AMD64 TFLite C++ tests (8kHz)"
+    description: "Testing DeepSpeech C++ for Linux/AMD64, TFLite, optimized version (8kHz)"

View file

@@ -1,13 +0,0 @@
-build:
-  template_file: test-linux-opt-base.tyml
-  dependencies:
-    - "linux-amd64-ctc-opt"
-  system_setup:
-    >
-      apt-get -qq update && apt-get -qq -y install ${training.packages_xenial.apt}
-  args:
-    tests_cmdline: "${system.homedir.linux}/DeepSpeech/ds/taskcluster/tc-train-extra-tests.sh 3.5.8:m 16k"
-  workerType: "${docker.dsTests}"
-  metadata:
-    name: "DeepSpeech Linux AMD64 CPU 8kHz all training features Py3.7"
-    description: "Training (all features) a DeepSpeech LDC93S1 model for Linux/AMD64 8kHz Python 3.7, CPU only, optimized version"

View file

@@ -9,5 +9,5 @@ build:
     tests_cmdline: "${system.homedir.linux}/DeepSpeech/ds/taskcluster/tc-train-extra-tests.sh 3.6.10:m 16k"
   workerType: "${docker.dsTests}"
   metadata:
-    name: "DeepSpeech Linux AMD64 CPU 8kHz all training features Py3.7"
-    description: "Training (all features) a DeepSpeech LDC93S1 model for Linux/AMD64 8kHz Python 3.7, CPU only, optimized version"
+    name: "DeepSpeech Linux AMD64 CPU 8kHz all training features Py3.6"
+    description: "Training (all features) a DeepSpeech LDC93S1 model for Linux/AMD64 8kHz Python 3.6, CPU only, optimized version"

View file

@@ -9,5 +9,5 @@ build:
     tests_cmdline: "${system.homedir.linux}/DeepSpeech/ds/taskcluster/tc-train-extra-tests.sh 3.6.10:m 8k"
   workerType: "${docker.dsTests}"
   metadata:
-    name: "DeepSpeech Linux AMD64 CPU 8kHz all training features Py3.7"
-    description: "Training (all features) a DeepSpeech LDC93S1 model for Linux/AMD64 8kHz Python 3.7, CPU only, optimized version"
+    name: "DeepSpeech Linux AMD64 CPU 8kHz all training features Py3.6"
+    description: "Training (all features) a DeepSpeech LDC93S1 model for Linux/AMD64 8kHz Python 3.6, CPU only, optimized version"

View file

@@ -1,16 +0,0 @@
-build:
-  template_file: test-linux-opt-tag-base.tyml
-  dependencies:
-    - "scriptworker-task-pypi"
-  allowed:
-    - "tag"
-  ref_match: "refs/tags/"
-  system_setup:
-    >
-      apt-get -qq update && apt-get -qq -y install ${training.packages_xenial.apt}
-  args:
-    tests_cmdline: "${system.homedir.linux}/DeepSpeech/ds/taskcluster/tc-train-tests.sh 3.5.8:m 16k --pypi"
-  workerType: "${docker.dsTests}"
-  metadata:
-    name: "DeepSpeech Linux AMD64 CPU 16kHz PyPI training Py3.5"
-    description: "Training a DeepSpeech LDC93S1 model for Linux/AMD64 16kHz Python 3.5, CPU only, optimized version, decoder package from PyPI"

View file

@@ -1,13 +0,0 @@
-build:
-  template_file: test-linux-opt-base.tyml
-  dependencies:
-    - "linux-amd64-ctc-opt"
-  system_setup:
-    >
-      apt-get -qq update && apt-get -qq -y install ${training.packages_xenial.apt}
-  args:
-    tests_cmdline: "${system.homedir.linux}/DeepSpeech/ds/taskcluster/tc-train-unittests.sh 3.5.8:m"
-  workerType: "${docker.dsTests}"
-  metadata:
-    name: "DeepSpeech on Linux AMD64 CPU training unittests using Python 3.5"
-    description: "Training unittests DeepSpeech LDC93S1 model for Linux/AMD64 using Python 3.5, for CPU only, and optimized version"

View file

@@ -1,13 +0,0 @@
-build:
-  template_file: test-linux-opt-base.tyml
-  dependencies:
-    - "linux-amd64-ctc-opt"
-  system_setup:
-    >
-      apt-get -qq update && apt-get -qq -y install ${training.packages_xenial.apt}
-  args:
-    tests_cmdline: "${system.homedir.linux}/DeepSpeech/ds/taskcluster/tc-train-tests.sh 3.5.8:m 16k"
-  workerType: "${docker.dsTests}"
-  metadata:
-    name: "DeepSpeech Linux AMD64 CPU 16kHz basic training Py3.5"
-    description: "Training a DeepSpeech LDC93S1 model for Linux/AMD64 16kHz Python 3.5, CPU only, optimized version"

View file

@@ -1,13 +0,0 @@
-build:
-  template_file: test-linux-opt-base.tyml
-  dependencies:
-    - "test-training_16k-linux-amd64-py36m-opt"
-  system_setup:
-    >
-      apt-get -qq update && apt-get -qq -y install ${training.packages_xenial.apt} ${python.packages_xenial.apt}
-  args:
-    tests_cmdline: "${system.homedir.linux}/DeepSpeech/ds/taskcluster/tc-transcribe-tests.sh 3.5.8:m 16k"
-  workerType: "${docker.dsTests}"
-  metadata:
-    name: "DeepSpeech Linux AMD64 CPU 16kHz transcribe Py3.5"
-    description: "Transcribe a DeepSpeech LDC93S1 model for Linux/AMD64 16kHz Python 3.5, CPU only, optimized version"

View file

@@ -9943,3 +9943,37 @@
    fun:_ZN12TFModelState4initEPKc
    fun:DS_CreateModel
 }
+{
+   <tensorflow_full>
+   Memcheck:Leak
+   match-leak-kinds: reachable
+   fun:_Znwm
+   fun:_ZN10tensorflow6thread16EigenEnvironment10CreateTaskESt8functionIFvvEE
+   fun:_ZN5Eigen15ThreadPoolTemplIN10tensorflow6thread16EigenEnvironmentEE16ScheduleWithHintESt8functionIFvvEEii
+   fun:_ZN5Eigen15ThreadPoolTemplIN10tensorflow6thread16EigenEnvironmentEE8ScheduleESt8functionIFvvEE
+   fun:_ZN10tensorflow6thread10ThreadPool8ScheduleESt8functionIFvvEE
+   fun:_ZZN10tensorflow13DirectSession11RunInternalExRKNS_10RunOptionsEPNS_18CallFrameInterfaceEPNS0_16ExecutorsAndKeysEPNS_11RunMetadataERKNS_6thread17ThreadPoolOptionsEENKUlSt8functionIFvvEEE4_clESG_
+   fun:_ZNSt17_Function_handlerIFvSt8functionIFvvEEEZN10tensorflow13DirectSession11RunInternalExRKNS4_10RunOptionsEPNS4_18CallFrameInterfaceEPNS5_16ExecutorsAndKeysEPNS4_11RunMetadataERKNS4_6thread17ThreadPoolOptionsEEUlS2_E4_E9_M_invokeERKSt9_Any_dataOS2_
+   fun:_ZNKSt8functionIFvS_IFvvEEEEclES1_
+   fun:_ZZN10tensorflow12_GLOBAL__N_113ExecutorStateINS_21SimplePropagatorStateEE6FinishEvENUlRKNS_6StatusEE1_clES6_
+   fun:_ZNSt17_Function_handlerIFvRKN10tensorflow6StatusEEZNS0_12_GLOBAL__N_113ExecutorStateINS0_21SimplePropagatorStateEE6FinishEvEUlS3_E1_E9_M_invokeERKSt9_Any_dataS3_
+   fun:_ZNKSt8functionIFvRKN10tensorflow6StatusEEEclES3_
+   fun:_ZN10tensorflow6Device4SyncERKSt8functionIFvRKNS_6StatusEEE
+}
+{
+   <tensorflow_full>
+   Memcheck:Leak
+   match-leak-kinds: reachable
+   fun:_Znwm
+   fun:_ZNSt14_Function_base13_Base_managerIZZN10tensorflow12_GLOBAL__N_113ExecutorStateINS1_21SimplePropagatorStateEE6FinishEvENUlRKNS1_6StatusEE1_clES8_EUlvE_E15_M_init_functorERSt9_Any_dataOSA_St17integral_constantIbLb0EE
+   fun:_ZNSt14_Function_base13_Base_managerIZZN10tensorflow12_GLOBAL__N_113ExecutorStateINS1_21SimplePropagatorStateEE6FinishEvENUlRKNS1_6StatusEE1_clES8_EUlvE_E15_M_init_functorERSt9_Any_dataOSA_
+   fun:_ZNSt8functionIFvvEEC1IZZN10tensorflow12_GLOBAL__N_113ExecutorStateINS3_21SimplePropagatorStateEE6FinishEvENUlRKNS3_6StatusEE1_clESA_EUlvE_vvEET_
+   fun:_ZZN10tensorflow12_GLOBAL__N_113ExecutorStateINS_21SimplePropagatorStateEE6FinishEvENUlRKNS_6StatusEE1_clES6_
+   fun:_ZNSt17_Function_handlerIFvRKN10tensorflow6StatusEEZNS0_12_GLOBAL__N_113ExecutorStateINS0_21SimplePropagatorStateEE6FinishEvEUlS3_E1_E9_M_invokeERKSt9_Any_dataS3_
+   fun:_ZNKSt8functionIFvRKN10tensorflow6StatusEEEclES3_
+   fun:_ZN10tensorflow6Device4SyncERKSt8functionIFvRKNS_6StatusEEE
+   fun:_ZN10tensorflow12_GLOBAL__N_113ExecutorStateINS_21SimplePropagatorStateEE6FinishEv
+   fun:_ZN10tensorflow12_GLOBAL__N_113ExecutorStateINS_21SimplePropagatorStateEE14ScheduleFinishEv
+   fun:_ZN10tensorflow12_GLOBAL__N_113ExecutorStateINS_21SimplePropagatorStateEE7ProcessENS2_10TaggedNodeEx
+   fun:_ZSt13__invoke_implIvRMN10tensorflow12_GLOBAL__N_113ExecutorStateINS0_21SimplePropagatorStateEEEFvNS3_10TaggedNodeExERPS4_JRS5_RxEET_St21__invoke_memfun_derefOT0_OT1_DpOT2_
+}

View file

@@ -1,10 +1,12 @@
-import os
-import io
-import wave
-import math
-import tempfile
 import collections
+import ctypes
+import io
+import math
+import numpy as np
+import os
+import pyogg
+import tempfile
+import wave

 from .helpers import LimitingPool
 from collections import namedtuple
@@ -21,8 +23,9 @@ AUDIO_TYPE_NP = 'application/vnd.mozilla.np'
 AUDIO_TYPE_PCM = 'application/vnd.mozilla.pcm'
 AUDIO_TYPE_WAV = 'audio/wav'
 AUDIO_TYPE_OPUS = 'application/vnd.mozilla.opus'
-SERIALIZABLE_AUDIO_TYPES = [AUDIO_TYPE_WAV, AUDIO_TYPE_OPUS]
-LOADABLE_AUDIO_EXTENSIONS = {'.wav': AUDIO_TYPE_WAV}
+AUDIO_TYPE_OGG_OPUS = 'application/vnd.deepspeech.ogg_opus'
+SERIALIZABLE_AUDIO_TYPES = [AUDIO_TYPE_WAV, AUDIO_TYPE_OPUS, AUDIO_TYPE_OGG_OPUS]

 OPUS_PCM_LEN_SIZE = 4
 OPUS_RATE_SIZE = 4
@@ -73,6 +76,8 @@ class Sample:
         if audio_type in SERIALIZABLE_AUDIO_TYPES:
             self.audio = raw_data if isinstance(raw_data, io.BytesIO) else io.BytesIO(raw_data)
             self.duration = read_duration(audio_type, self.audio)
+            if not self.audio_format:
+                self.audio_format = read_format(audio_type, self.audio)
         else:
             self.audio = raw_data
             if self.audio_format is None:
@@ -133,10 +138,11 @@ def change_audio_types(packed_samples, audio_type=AUDIO_TYPE_PCM, bitrate=None,
     yield from pool.imap(_unpack_and_change_audio_type, map(lambda s: (s, audio_type, bitrate), packed_samples))

-def get_audio_type_from_extension(ext):
-    if ext in LOADABLE_AUDIO_EXTENSIONS:
-        return LOADABLE_AUDIO_EXTENSIONS[ext]
-    return None
+def get_loadable_audio_type_from_extension(ext):
+    return {
+        '.wav': AUDIO_TYPE_WAV,
+        '.opus': AUDIO_TYPE_OGG_OPUS,
+    }.get(ext, None)

 def read_audio_format_from_wav_file(wav_file):
@@ -340,6 +346,102 @@ def read_opus(opus_file):
     return audio_format, bytes(audio_data)

+def read_ogg_opus(ogg_file):
+    error = ctypes.c_int()
+    ogg_file_buffer = ogg_file.getbuffer()
+    ubyte_array = ctypes.c_ubyte * len(ogg_file_buffer)
+    opusfile = pyogg.opus.op_open_memory(
+        ubyte_array.from_buffer(ogg_file_buffer),
+        len(ogg_file_buffer),
+        ctypes.pointer(error)
+    )
+
+    if error.value != 0:
+        raise ValueError(
+            ("Ogg/Opus buffer could not be read."
+             "Error code: {}").format(error.value)
+        )
+
+    channel_count = pyogg.opus.op_channel_count(opusfile, -1)
+    sample_rate = 48000  # opus files are always 48kHz
+    sample_width = 2  # always 16-bit
+    audio_format = AudioFormat(sample_rate, channel_count, sample_width)
+
+    # Allocate sufficient memory to store the entire PCM
+    pcm_size = pyogg.opus.op_pcm_total(opusfile, -1)
+    Buf = pyogg.opus.opus_int16 * (pcm_size * channel_count)
+    buf = Buf()
+
+    # Create a pointer to the newly allocated memory.  It
+    # seems we can only do pointer arithmetic on void
+    # pointers.  See
+    # https://mattgwwalker.wordpress.com/2020/05/30/pointer-manipulation-in-python/
+    buf_ptr = ctypes.cast(
+        ctypes.pointer(buf),
+        ctypes.c_void_p
+    )
+    assert buf_ptr.value is not None  # for mypy
+    buf_ptr_zero = buf_ptr.value
+
+    #: Bytes per sample
+    bytes_per_sample = ctypes.sizeof(pyogg.opus.opus_int16)
+
+    # Read through the entire file, copying the PCM into the
+    # buffer
+    samples = 0
+    while True:
+        # Calculate remaining buffer size
+        remaining_buffer = (
+            len(buf)  # int
+            - (buf_ptr.value - buf_ptr_zero) // bytes_per_sample
+        )
+
+        # Convert buffer pointer to the desired type
+        ptr = ctypes.cast(
+            buf_ptr,
+            ctypes.POINTER(pyogg.opus.opus_int16)
+        )
+
+        # Read the next section of PCM
+        ns = pyogg.opus.op_read(
+            opusfile,
+            ptr,
+            remaining_buffer,
+            pyogg.ogg.c_int_p()
+        )
+
+        # Check for errors
+        if ns < 0:
+            raise ValueError(
+                "Error while reading OggOpus buffer. " +
+                "Error code: {}".format(ns)
+            )
+
+        # Increment the pointer
+        buf_ptr.value += (
+            ns
+            * bytes_per_sample
+            * channel_count
+        )
+        assert buf_ptr.value is not None  # for mypy
+
+        samples += ns
+
+        # Check if we've finished
+        if ns == 0:
+            break
+
+    # Close the open file
+    pyogg.opus.op_free(opusfile)
+
+    # Cast buffer to a one-dimensional array of chars
+    #: Raw PCM data from audio file.
+    CharBuffer = ctypes.c_byte * (bytes_per_sample * channel_count * pcm_size)
+    audio_data = CharBuffer.from_buffer(buf)
+
+    return audio_format, audio_data
+
 def write_wav(wav_file, pcm_data, audio_format=DEFAULT_FORMAT):
     # wav_file is already a file-pointer here
     with wave.open(wav_file, 'wb') as wav_file_writer:
@@ -362,6 +464,8 @@ def read_audio(audio_type, audio_file):
         return read_wav(audio_file)
     if audio_type == AUDIO_TYPE_OPUS:
         return read_opus(audio_file)
+    if audio_type == AUDIO_TYPE_OGG_OPUS:
+        return read_ogg_opus(audio_file)
     raise ValueError('Unsupported audio type: {}'.format(audio_type))
@@ -384,11 +488,83 @@ def read_opus_duration(opus_file):
     return get_pcm_duration(pcm_buffer_size, audio_format)

+def read_ogg_opus_duration(ogg_file):
+    error = ctypes.c_int()
+    ogg_file_buffer = ogg_file.getbuffer()
+    ubyte_array = ctypes.c_ubyte * len(ogg_file_buffer)
+    opusfile = pyogg.opus.op_open_memory(
+        ubyte_array.from_buffer(ogg_file_buffer),
+        len(ogg_file_buffer),
+        ctypes.pointer(error)
+    )
+
+    if error.value != 0:
+        raise ValueError(
+            ("Ogg/Opus buffer could not be read."
+             "Error code: {}").format(error.value)
+        )
+
+    pcm_buffer_size = pyogg.opus.op_pcm_total(opusfile, -1)
+    channel_count = pyogg.opus.op_channel_count(opusfile, -1)
+    sample_rate = 48000  # opus files are always 48kHz
+    sample_width = 2  # always 16-bit
+    audio_format = AudioFormat(sample_rate, channel_count, sample_width)
+    pyogg.opus.op_free(opusfile)
+    return get_pcm_duration(pcm_buffer_size, audio_format)
+
 def read_duration(audio_type, audio_file):
     if audio_type == AUDIO_TYPE_WAV:
         return read_wav_duration(audio_file)
     if audio_type == AUDIO_TYPE_OPUS:
         return read_opus_duration(audio_file)
+    if audio_type == AUDIO_TYPE_OGG_OPUS:
+        return read_ogg_opus_duration(audio_file)
     raise ValueError('Unsupported audio type: {}'.format(audio_type))
+
+def read_wav_format(wav_file):
+    wav_file.seek(0)
+    with wave.open(wav_file, 'rb') as wav_file_reader:
+        return read_audio_format_from_wav_file(wav_file_reader)
+
+def read_opus_format(opus_file):
+    _, audio_format = read_opus_header(opus_file)
+    return audio_format
+
+def read_ogg_opus_format(ogg_file):
+    error = ctypes.c_int()
+    ogg_file_buffer = ogg_file.getbuffer()
+    ubyte_array = ctypes.c_ubyte * len(ogg_file_buffer)
+    opusfile = pyogg.opus.op_open_memory(
+        ubyte_array.from_buffer(ogg_file_buffer),
+        len(ogg_file_buffer),
+        ctypes.pointer(error)
+    )
+
+    if error.value != 0:
+        raise ValueError(
+            ("Ogg/Opus buffer could not be read."
+             "Error code: {}").format(error.value)
+        )
+
+    channel_count = pyogg.opus.op_channel_count(opusfile, -1)
+    pyogg.opus.op_free(opusfile)
+
+    sample_rate = 48000  # opus files are always 48kHz
+    sample_width = 2  # always 16-bit
+    return AudioFormat(sample_rate, channel_count, sample_width)
+
+def read_format(audio_type, audio_file):
+    if audio_type == AUDIO_TYPE_WAV:
+        return read_wav_format(audio_file)
+    if audio_type == AUDIO_TYPE_OPUS:
+        return read_opus_format(audio_file)
+    if audio_type == AUDIO_TYPE_OGG_OPUS:
+        return read_ogg_opus_format(audio_file)
+    raise ValueError('Unsupported audio type: {}'.format(audio_type))
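Usage note: with the readers added above, a serialized Ogg/Opus sample can be inspected straight from an in-memory buffer. A hedged sketch (the path sample.opus is a hypothetical input file; pyogg must be able to locate libopusfile at runtime):

    import io
    from deepspeech_training.util.audio import AUDIO_TYPE_OGG_OPUS, read_audio, read_duration, read_format

    with open('sample.opus', 'rb') as f:  # hypothetical Ogg/Opus file
        buf = io.BytesIO(f.read())

    print(read_format(AUDIO_TYPE_OGG_OPUS, buf))    # AudioFormat: rate=48000, 16-bit
    print(read_duration(AUDIO_TYPE_OGG_OPUS, buf))  # duration in seconds
    audio_format, pcm = read_audio(AUDIO_TYPE_OGG_OPUS, buf)  # decoded 16-bit PCM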

View file

@@ -3,6 +3,7 @@ import os
 import re
 import math
 import random
+import resampy
 import numpy as np

 from multiprocessing import Queue, Process
@@ -129,7 +130,7 @@ def apply_graph_augmentations(domain, tensor, augmentations, transcript=None, cl
         Tensor of type float32
         The augmented spectrogram
     """
-    if augmentations is not None:
+    if augmentations:
         for augmentation in augmentations:
             if isinstance(augmentation, GraphAugmentation):
                 tensor = augmentation.maybe_apply(domain, tensor, transcript=transcript, clock=clock)
@@ -348,24 +349,25 @@ class Resample(SampleAugmentation):
         self.rate = int_range(rate)

     def apply(self, sample, clock=0.0):
-        # late binding librosa and its dependencies
-        # pre-importing sklearn fixes https://github.com/scikit-learn/scikit-learn/issues/14485
-        import sklearn  # pylint: disable=import-outside-toplevel
-        from librosa.core import resample  # pylint: disable=import-outside-toplevel
         sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP)
         rate = pick_value_from_range(self.rate, clock=clock)
-        audio = sample.audio
-        orig_len = len(audio)
-        audio = np.swapaxes(audio, 0, 1)
-        if audio.shape[0] < 2:
-            # since v0.8 librosa enforces a shape of (samples,) instead of (channels, samples) for mono samples
-            resampled = resample(audio[0], sample.audio_format.rate, rate)
-            audio[0] = resample(resampled, rate, sample.audio_format.rate)[:orig_len]
-        else:
-            audio = resample(audio, sample.audio_format.rate, rate)
-            audio = resample(audio, rate, sample.audio_format.rate)
-            audio = np.swapaxes(audio, 0, 1)[0:orig_len]
-        sample.audio = audio
+        orig_len = len(sample.audio)
+        resampled = resampy.resample(sample.audio, sample.audio_format.rate, rate, axis=0, filter='kaiser_fast')
+        sample.audio = resampy.resample(resampled, rate, sample.audio_format.rate, axis=0, filter='kaiser_fast')[:orig_len]
+
+class NormalizeSampleRate(SampleAugmentation):
+    def __init__(self, rate):
+        super().__init__(p=1.0)
+        self.rate = rate
+
+    def apply(self, sample, clock=0.0):
+        if sample.audio_format.rate == self.rate:
+            return
+        sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP)
+        sample.audio = resampy.resample(sample.audio, sample.audio_format.rate, self.rate, axis=0, filter='kaiser_fast')
+        sample.audio_format = sample.audio_format._replace(rate=self.rate)

 class Volume(SampleAugmentation):
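The Resample augmentation now degrades audio by round-tripping through resampy instead of late-binding librosa. A minimal sketch of that round-trip, assuming a mono float32 buffer at 16 kHz and an illustrative intermediate rate of 8 kHz:

    import numpy as np
    import resampy

    audio = np.random.uniform(-1, 1, 16000).astype(np.float32)  # 1 s of audio at 16 kHz
    orig_len = len(audio)
    rate = 8000  # intermediate rate; the augmentation picks this from a configured range

    # Downsample to the intermediate rate, then back up, trimming to the original length
    resampled = resampy.resample(audio, 16000, rate, axis=0, filter='kaiser_fast')
    audio = resampy.resample(resampled, rate, 16000, axis=0, filter='kaiser_fast')[:orig_len]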

View file

@@ -12,7 +12,7 @@ from .flags import FLAGS
 from .gpu import get_available_gpus
 from .logging import log_error, log_warn
 from .helpers import parse_file_size
-from .augmentations import parse_augmentations
+from .augmentations import parse_augmentations, NormalizeSampleRate
 from .io import path_exists_remote

 class ConfigSingleton:
@@ -33,11 +33,14 @@ def initialize_globals():
     # Augmentations
     c.augmentations = parse_augmentations(FLAGS.augment)
-    if len(c.augmentations) > 0 and FLAGS.feature_cache and FLAGS.cache_for_epochs == 0:
+    if c.augmentations and FLAGS.feature_cache and FLAGS.cache_for_epochs == 0:
         log_warn('Due to current feature-cache settings the exact same sample augmentations of the first '
                  'epoch will be repeated on all following epochs. This could lead to unintended over-fitting. '
                  'You could use --cache_for_epochs <n_epochs> to invalidate the cache after a given number of epochs.')

+    if FLAGS.normalize_sample_rate:
+        c.augmentations = [NormalizeSampleRate(FLAGS.audio_sample_rate)] + c['augmentations']
+
     # Caching
     if FLAGS.cache_for_epochs == 1:
         log_warn('--cache_for_epochs == 1 is (re-)creating the feature cache on every epoch but will never use it.')
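With --normalize_sample_rate enabled (the default), the augmentation pipeline is effectively prefixed with a mandatory rate normalization. A sketch of the resulting list (the volume spec string is illustrative, not taken from this diff):

    from deepspeech_training.util.augmentations import parse_augmentations, NormalizeSampleRate

    augmentations = parse_augmentations(['volume[p=0.5,dbfs=-10:0]'])  # illustrative --augment value
    augmentations = [NormalizeSampleRate(16000)] + augmentations  # what initialize_globals() now does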

View file

@@ -28,7 +28,7 @@ def audio_to_features(audio, sample_rate, transcript=None, clock=0.0, train_phas
                           lambda: tf.no_op(),
                           name='matching_sample_rate')

-    if train_phase and augmentations is not None:
+    if train_phase and augmentations:
         audio = apply_graph_augmentations('signal', audio, augmentations, transcript=transcript, clock=clock)

     spectrogram = contrib_audio.audio_spectrogram(audio,
@@ -36,7 +36,7 @@ def audio_to_features(audio, sample_rate, transcript=None, clock=0.0, train_phas
                                                   stride=Config.audio_step_samples,
                                                   magnitude_squared=True)

-    if train_phase and augmentations is not None:
+    if train_phase and augmentations:
         spectrogram = apply_graph_augmentations('spectrogram', spectrogram, augmentations, transcript=transcript, clock=clock)

     features = contrib_audio.mfcc(spectrogram=spectrogram,
@@ -45,7 +45,7 @@ def audio_to_features(audio, sample_rate, transcript=None, clock=0.0, train_phas
                                   upper_frequency_limit=FLAGS.audio_sample_rate / 2)
     features = tf.reshape(features, [-1, Config.n_input])

-    if train_phase and augmentations is not None:
+    if train_phase and augmentations:
         features = apply_graph_augmentations('features', features, augmentations, transcript=transcript, clock=clock)

     return features, tf.shape(input=features)[0]

View file

@@ -24,6 +24,7 @@ def create_flags():
     f.DEFINE_integer('feature_win_len', 32, 'feature extraction audio window length in milliseconds')
     f.DEFINE_integer('feature_win_step', 20, 'feature extraction window step length in milliseconds')
     f.DEFINE_integer('audio_sample_rate', 16000, 'sample rate value expected by model')
+    f.DEFINE_boolean('normalize_sample_rate', True, 'normalize sample rate of all train_files to --audio_sample_rate')

     # Data Augmentation
     # ================

View file

@@ -11,11 +11,10 @@ from functools import partial
 from .helpers import KILOBYTE, MEGABYTE, GIGABYTE, Interleaved, LenMap
 from .audio import (
     Sample,
-    DEFAULT_FORMAT,
     AUDIO_TYPE_PCM,
     AUDIO_TYPE_OPUS,
     SERIALIZABLE_AUDIO_TYPES,
-    get_audio_type_from_extension,
+    get_loadable_audio_type_from_extension,
     write_wav
 )
 from .io import open_remote, is_remote_path
@@ -40,7 +39,7 @@ CONTENT_TYPE_TRANSCRIPT = 'transcript'
 class LabeledSample(Sample):
     """In-memory labeled audio sample representing an utterance.
     Derived from util.audio.Sample and used by sample collection readers and writers."""
-    def __init__(self, audio_type, raw_data, transcript, audio_format=DEFAULT_FORMAT, sample_id=None):
+    def __init__(self, audio_type, raw_data, transcript, audio_format=None, sample_id=None):
         """
         Parameters
         ----------
@@ -110,7 +109,7 @@ def load_sample(filename, label=None):
         util.audio.Sample instance if label is None, else util.sample_collections.LabeledSample instance
     """
     ext = os.path.splitext(filename)[1].lower()
-    audio_type = get_audio_type_from_extension(ext)
+    audio_type = get_loadable_audio_type_from_extension(ext)
     if audio_type is None:
         raise ValueError('Unknown audio type extension "{}"'.format(ext))
     return PackedSample(filename, audio_type, label)
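End to end, the .opus extension now resolves to AUDIO_TYPE_OGG_OPUS, so a labeled sample can be loaded and decoded through the existing helpers. A hedged sketch (utterance.opus and its transcript are hypothetical):

    from deepspeech_training.util.audio import AUDIO_TYPE_NP
    from deepspeech_training.util.sample_collections import load_sample

    sample = load_sample('utterance.opus', label='hello world').unpack()  # LabeledSample
    sample.change_audio_type(AUDIO_TYPE_NP)  # decode Ogg/Opus to a numpy buffer
    print(sample.duration, sample.audio_format)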