Refactoring of TF based augmentations

This commit is contained in:
Tilman Kamp 2020-06-10 13:42:45 +02:00
Родитель bfaa68945a
Коммит d94db7ca43
15 изменённых файлов: 820 добавлений и 634 удалений

Просмотреть файл

@ -10,7 +10,8 @@ import random
import argparse
from deepspeech_training.util.audio import LOADABLE_AUDIO_EXTENSIONS, AUDIO_TYPE_PCM, AUDIO_TYPE_WAV
from deepspeech_training.util.sample_collections import SampleList, LabeledSample, samples_from_source, augment_samples
from deepspeech_training.util.sample_collections import SampleList, LabeledSample, samples_from_source
from deepspeech_training.util.augmentations import parse_augmentations, apply_sample_augmentations, SampleAugmentation
def get_samples_in_play_order():
@ -38,12 +39,15 @@ def get_samples_in_play_order():
def play_collection():
augmentations = parse_augmentations(CLI_ARGS.augment)
if any(map(lambda a: not isinstance(a, SampleAugmentation), augmentations)):
print("Warning: Some of the augmentations cannot be simulated by this command.")
samples = get_samples_in_play_order()
samples = augment_samples(samples,
audio_type=AUDIO_TYPE_PCM,
augmentation_specs=CLI_ARGS.augment,
process_ahead=0,
fixed_clock=CLI_ARGS.clock)
samples = apply_sample_augmentations(samples,
audio_type=AUDIO_TYPE_PCM,
augmentations=augmentations,
process_ahead=0,
clock=CLI_ARGS.clock)
for sample in samples:
if not CLI_ARGS.quiet:
print('Sample "{}"'.format(sample.sample_id), file=sys.stderr)

Просмотреть файл

@ -0,0 +1,28 @@
#!/bin/sh
set -xe
ldc93s1_dir="./data/smoke_test"
ldc93s1_csv="${ldc93s1_dir}/ldc93s1.csv"
if [ ! -f "${ldc93s1_dir}/ldc93s1.csv" ]; then
echo "Downloading and preprocessing LDC93S1 example data, saving in ${ldc93s1_dir}."
python -u bin/import_ldc93s1.py ${ldc93s1_dir}
fi;
# Force only one visible device because we have a single-sample dataset
# and when trying to run on multiple devices (like GPUs), this will break
export CUDA_VISIBLE_DEVICES=0
python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
--train_files ${ldc93s1_csv} --train_batch_size 1 \
--augment speed \
--augment dropout \
--augment pitch_and_tempo \
--augment time_mask \
--augment frequency_mask \
--augment add \
--augment multiply \
--augment warp \
--n_hidden 100 \
--epochs 1

Просмотреть файл

@ -41,12 +41,6 @@ if ! $compare --if-differ "${ldc93s1_wav}" /tmp/reverb-test.wav; then
exit 1
fi
$play ${ldc93s1_wav} --augment gaps[n=10,size=100.0] --pipe >/tmp/gaps-test.wav
if ! $compare --if-differ "${ldc93s1_wav}" /tmp/gaps-test.wav; then
echo "Gaps augmentation had no effect or changed basic sample properties"
exit 1
fi
$play ${ldc93s1_wav} --augment resample[rate=4000] --pipe >/tmp/resample-test.wav
if ! $compare --if-differ "${ldc93s1_wav}" /tmp/resample-test.wav; then
echo "Resample augmentation had no effect or changed basic sample properties"

Просмотреть файл

@ -270,12 +270,6 @@ Augmentation
Augmentation is a useful technique for better generalization of machine learning models. Thus, a pre-processing pipeline with various augmentation techniques on raw pcm and spectrogram has been implemented and can be used while training the model. Following are the available augmentation techniques that can be enabled at training time by using the corresponding flags in the command line.
Audio Augmentation
------------------
Augmentations that are applied before potential feature caching can be specified through the ``--augment`` flag. Being a multi-flag, it can be specified multiple times (see below for an example).
Each sample of the training data will get treated by every specified augmentation in their given order. However: whether an augmentation will actually get applied to a sample is decided by chance on base of the augmentation's probability value. For example a value of ``p=0.1`` would apply the according augmentation to just 10% of all samples. This also means that augmentations are not mutually exclusive on a per-sample basis.
The ``--augment`` flag uses a common syntax for all augmentation types:
@ -297,14 +291,31 @@ In the documentation below, whenever a value is specified as ``<float-range>`` o
* ``<value>~<r>``: A center value with a randomization radius around it. E.g. ``1.2~0.4`` will result in picking of a uniformly random value between 0.8 and 1.6 on each sample augmentation.
* ``<start>:<end>``: The value will range from `<start>` at the beginning of an epoch to `<end>` at the end of an epoch. E.g. ``-0.2:1.2`` (float) or ``2000:4000`` (int)
* ``<start>:<end>``: The value will range from `<start>` at the beginning of the training to `<end>` at the end of the training. E.g. ``-0.2:1.2`` (float) or ``2000:4000`` (int)
* ``<start>:<end>~<r>``: Combination of the two previous cases with a ranging center value. E.g. ``4-6~2`` would at the beginning of an epoch pick values between 2 and 6 and at the end of an epoch between 4 and 8.
* ``<start>:<end>~<r>``: Combination of the two previous cases with a ranging center value. E.g. ``4-6~2`` would at the beginning of the training pick values between 2 and 6 and at the end of the training between 4 and 8.
Ranges specified with integer limits will only assume integer (rounded) values.
If feature caching is enabled, these augmentations will only be performed on the first epoch and the result will be reused for subsequent epochs. The flag ``--augmentations_per_epoch N`` (by default `N` is 1) could be used to get more than one epoch worth of augmentations into the cache. During training, each epoch will do ``N`` passes over the training set, each time performing augmentation independently of previous passes. Be aware: this will also multiply the required size of the feature cache if it's enabled.
.. warning::
If feature caching is enabled and infinite (default), these augmentations will only be performed on first epoch and the result will be reused for subsequent epochs. This would not only hinder value ranges from reaching their intended final values, but could also lead to unintended over-fitting. In this case flag ``--cache_for_epochs N`` (with N > 1) should be used to periodically invalidate the cache and thus allow samples to be re-augmented in new ways and with current range-values.
Every augmentation is targeting a certain data representation of the sample - further on called *domain*.
Augmentations are applied domain-wise in the following order:
1. **sample** domain: The sample just got loaded and its waveform is represented as a NumPy array. For implementation reasons these augmentations are the only ones that can be "simulated" through ``bin/play.py``.
2. **signal** domain: The sample waveform is represented as a tensor.
3. **spectrogram** domain: The sample spectrogram is represented as a tensor.
4. **features** domain: The sample's MEL spectrogram features are represented as a tensor.
During each phase augmentations are applied in command-line order (the **warp** augmentation being the only exception).
Sample domain augmentations
---------------------------
**Overlay augmentation** ``--augment overlay[p=<float>,source=<str>,snr=<float-range>,layers=<int-range>]``
Layers another audio source (multiple times) onto augmented samples.
@ -328,16 +339,6 @@ If feature caching is enabled, these augmentations will only be performed on the
* **decay**: sound decay in dB per reflection - higher values will result in a less reflective perceived "room"
**Gaps augmentation** ``--augment gaps[p=<float>,n=<int-range>,size=<float-range>]``
Sets time-intervals within the augmented samples to zero (silence) at random positions.
* **p**: probability value between 0.0 (never) and 1.0 (always) if a given sample gets augmented by this method
* **n**: number of intervals to set to zero
* **size**: duration of intervals in ms
**Resample augmentation** ``--augment resample[p=<float>,rate=<int-range>]``
Resamples augmented samples to another sample rate and then resamples back to the original sample rate.
@ -361,6 +362,96 @@ If feature caching is enabled, these augmentations will only be performed on the
* **dbfs** : target volume in dBFS (default value of 3.0103 will normalize min and max amplitudes to -1.0/1.0)
Spectrogram domain augmentations
--------------------------------
**Pitch and tempo augmentation** ``--augment pitch_and_tempo[p=<float>,pitch=<float-range>,tempo=<float-range>]``
Scales spectrogram on time and frequency axis and thus changes pitch and playback tempo.
* **p**: probability value between 0.0 (never) and 1.0 (always) if a given sample gets augmented by this method
* **pitch**: pitch factor by with the frequency axis is scaled (e.g. a value of 2.0 will raise audio frequency by one octave)
* **tempo**: tempo factor by which the time axis is stretched or shrunken (e.g. a value of 2.0 will double playback tempo)
**Speed augmentation** ``--augment speed[p=<float>,factor=<float-range>]``
Scales spectrogram on time axis and thus changes playback tempo.
* **p**: probability value between 0.0 (never) and 1.0 (always) if a given sample gets augmented by this method
* **factor**: speed factor by which the time axis is stretched or shrunken (e.g. a value of 2.0 will double playback tempo)
**Warp augmentation** ``--augment warp[p=<float>,shift=<float-range>,order=<int-range>,nbp=<int-range>,ncp=<int-range>,regularization_weight=<float>]``
Applies a non-linear image warp to the spectrogram, where the warp is specified by the source and destination locations of a (potentially small) number of control points. Of all specified spectrogram augmentations this one will always be applied first.
* **p**: probability value between 0.0 (never) and 1.0 (always) if a given sample gets augmented by this method
* **shift**: maximum shift distance of control points on time axis in ms
* **order**: polynomial order used by the spline interpolation
* **nbp**: how many zero-flow boundary points to include at each spectrogram edge
* **ncp**: how many control points to warp inside the spectrogram
* **regularization_weight**: weight on smoothness regularizer in interpolation
**Frequency mask augmentation** ``--augment frequency_mask[p=<float>,n=<int-range>,size=<int-range>]``
Sets frequency-intervals within the augmented samples to zero (silence) at random frequencies.
* **p**: probability value between 0.0 (never) and 1.0 (always) if a given sample gets augmented by this method
* **n**: number of intervals to mask
* **size**: number of frequency bands to mask per interval
Multi domain augmentations
--------------------------
**Time mask augmentation** ``--augment time_mask[p=<float>,n=<int-range>,size=<float-range>,domain=<domain>]``
Sets time-intervals within the augmented samples to zero (silence) at random positions.
* **p**: probability value between 0.0 (never) and 1.0 (always) if a given sample gets augmented by this method
* **n**: number of intervals to set to zero
* **size**: duration of intervals in ms
* **domain**: data representation to apply augmentation to - "signal", "features" or "spectrogram" (default)
**Dropout augmentation** ``--augment dropout[p=<float>,rate=<float-range>,domain=<domain>]``
Zeros random data points of the targeted data representation.
* **p**: probability value between 0.0 (never) and 1.0 (always) if a given sample gets augmented by this method
* **rate**: dropout rate ranging from 0.0 for no dropout to 1.0 for 100% dropout
* **domain**: data representation to apply augmentation to - "signal", "features" or "spectrogram" (default)
**Add augmentation** ``--augment add[p=<float>,stddev=<float-range>,domain=<domain>]``
Adds random values picked from a normal distribution (with a mean of 0.0) to all data points of the targeted data representation.
* **p**: probability value between 0.0 (never) and 1.0 (always) if a given sample gets augmented by this method
* **stddev**: standard deviation of the normal distribution to pick values from
* **domain**: data representation to apply augmentation to - "signal", "features" (default) or "spectrogram"
**Multiply augmentation** ``--augment multiply[p=<float>,stddev=<float-range>,domain=<domain>]``
Multiplies all data points of the targeted data representation with random values picked from a normal distribution (with a mean of 1.0).
* **p**: probability value between 0.0 (never) and 1.0 (always) if a given sample gets augmented by this method
* **stddev**: standard deviation of the normal distribution to pick values from
* **domain**: data representation to apply augmentation to - "signal", "features" (default) or "spectrogram"
Example training with all augmentations:
@ -368,18 +459,26 @@ Example training with all augmentations:
python -u DeepSpeech.py \
--train_files "train.sdb" \
--augmentations_per_epoch 10 \
--feature_cache ./feature.cache \
--cache_for_epochs 10 \
--epochs 100 \
--augment overlay[p=0.5,source=noise.sdb,layers=1,snr=50:20~10] \
--augment overlay[p=0.2,source=voices.sdb,layers=10:6,snr=50:20~10] \
--augment reverb[p=0.1,delay=50.0~30.0,decay=10.0:2.0~1.0] \
--augment gaps[p=0.05,n=1:3~2,size=10:100] \
--augment resample[p=0.1,rate=12000:8000~4000] \
--augment codec[p=0.1,bitrate=48000:16000] \
--augment volume[p=0.1,dbfs=-10:-40] \
--augment pitch_and_tempo[p=0.1,pitch=1~0.2,tempo=1~0.2] \
--augment speed[p=0.1,factor=1~0.5] \
--augment warp[p=0.1,shift=30:60~20,ncp=4~3] \
--augment frequency_mask[p=0.1,n=1:3,size=1:5] \
--augment time_mask[p=0.1,domain=signal,n=3:10~2,size=50:100~40] \
--augment dropout[p=0.1,rate=0.05] \
--augment add[p=0.1,domain=signal,stddev=0~0.5] \
--augment multiply[p=0.1,domain=features,stddev=0~0.5] \
[...]
The ``bin/play.py`` tool also supports ``--augment`` parameters and can be used for experimenting with different configurations.
The ``bin/play.py`` tool also supports ``--augment`` parameters (for sample domain augmentations) and can be used for experimenting with different configurations.
Example of playing all samples with reverberation and maximized volume:
@ -393,42 +492,3 @@ Example simulation of the codec augmentation of a wav-file first at the beginnin
bin/play.py --augment codec[p=0.1,bitrate=48000:16000] --clock 0.0 test.wav
bin/play.py --augment codec[p=0.1,bitrate=48000:16000] --clock 1.0 test.wav
The following augmentations are applied after feature caching, hence the way they are applied will not repeat epoch-wise.
Working on spectrogram and feature level, `bin/play.py` offers no ability to simulate them.
#. **Standard deviation for Gaussian additive noise:** ``--data_aug_features_additive``
#. **Standard deviation for Normal distribution around 1 for multiplicative noise:** ``--data_aug_features_multiplicative``
#. **Standard deviation for speeding-up tempo. If Standard deviation is 0, this augmentation is not performed:** ``--augmentation_speed_up_std``
Spectrogram Augmentation
------------------------
Inspired by Google Paper on `SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition <https://arxiv.org/abs/1904.08779>`_
#.
**Keep rate of dropout augmentation on a spectrogram (if 1, no dropout will be performed on the spectrogram)**\ :
* Keep Rate : ``--augmentation_spec_dropout_keeprate value between range [0 - 1]``
#.
**Whether to use frequency and time masking augmentation:**
* Enable / Disable : ``--augmentation_freq_and_time_masking / --noaugmentation_freq_and_time_masking``
* Max range of masks in the frequency domain when performing freqtime-mask augmentation: ``--augmentation_freq_and_time_masking_freq_mask_range eg: 5``
* Number of masks in the frequency domain when performing freqtime-mask augmentation: ``--augmentation_freq_and_time_masking_number_freq_masks eg: 3``
* Max range of masks in the time domain when performing freqtime-mask augmentation: ``--augmentation_freq_and_time_masking_time_mask_range eg: 2``
* Number of masks in the time domain when performing freqtime-mask augmentation: ``--augmentation_freq_and_time_masking_number_time_masks eg: 3``
#.
**Whether to use spectrogram speed and tempo scaling:**
* Enable / Disable : ``--augmentation_pitch_and_tempo_scaling / --noaugmentation_pitch_and_tempo_scaling``
* Min value of pitch scaling: ``--augmentation_pitch_and_tempo_scaling_min_pitch eg:0.95``
* Max value of pitch scaling: ``--augmentation_pitch_and_tempo_scaling_max_pitch eg:1.2``
* Max value of tempo scaling: ``--augmentation_pitch_and_tempo_scaling_max_tempo eg:1.2``

Просмотреть файл

@ -22,7 +22,8 @@ popd
set +o pipefail
pushd ${HOME}/DeepSpeech/ds/
time ./bin/run-tc-signal_augmentations.sh
time ./bin/run-tc-sample_augmentations.sh
time ./bin/run-tc-graph_augmentations.sh
popd
virtualenv_deactivate "${pyalias}" "deepspeech"

Просмотреть файл

@ -6,7 +6,7 @@ build:
>
apt-get -qq update && apt-get -qq -y install ${training.packages_trusty.apt}
args:
tests_cmdline: "${system.homedir.linux}/DeepSpeech/ds/taskcluster/tc-signal_augmentation-tests.sh 3.6.10:m"
tests_cmdline: "${system.homedir.linux}/DeepSpeech/ds/taskcluster/tc-augmentation-tests.sh 3.6.10:m"
metadata:
name: "DeepSpeech Linux AMD64 CPU signal augmentations Py3.6"
description: "Augmenting LDC93S1 sample in different ways for Linux/AMD64 16kHz Python 3.6, CPU only, optimized version"

Просмотреть файл

@ -10,7 +10,6 @@ DESIRED_LOG_LEVEL = sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.a
os.environ['TF_CPP_MIN_LOG_LEVEL'] = DESIRED_LOG_LEVEL
import absl.app
import json
import numpy as np
import progressbar
import shutil
@ -32,7 +31,7 @@ from six.moves import zip, range
from .util.config import Config, initialize_globals
from .util.checkpoints import load_or_init_graph_for_training, load_graph_for_evaluation
from .util.evaluate_tools import save_samples_json
from .util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features
from .util.feeding import create_dataset, audio_to_features, audiofile_to_features
from .util.flags import create_flags, FLAGS
from .util.helpers import check_ctcdecoder_version, ExceptionBox
from .util.logging import create_progressbar, log_debug, log_error, log_info, log_progress, log_warn
@ -407,26 +406,13 @@ def log_grads_and_vars(grads_and_vars):
def train():
do_cache_dataset = True
# pylint: disable=too-many-boolean-expressions
if (FLAGS.data_aug_features_multiplicative > 0 or
FLAGS.data_aug_features_additive > 0 or
FLAGS.augmentation_spec_dropout_keeprate < 1 or
FLAGS.augmentation_freq_and_time_masking or
FLAGS.augmentation_pitch_and_tempo_scaling or
FLAGS.augmentation_speed_up_std > 0 or
FLAGS.augmentation_sparse_warp):
do_cache_dataset = False
exception_box = ExceptionBox()
# Create training and validation datasets
train_set = create_dataset(FLAGS.train_files.split(','),
batch_size=FLAGS.train_batch_size,
repetitions=FLAGS.augmentations_per_epoch,
augmentation_specs=FLAGS.augment,
enable_cache=FLAGS.feature_cache and do_cache_dataset,
epochs=FLAGS.epochs,
augmentations=Config.augmentations,
cache_path=FLAGS.feature_cache,
train_phase=True,
exception_box=exception_box,
@ -541,6 +527,12 @@ def train():
step_summary_writer = step_summary_writers.get(set_name)
checkpoint_time = time.time()
if is_train and FLAGS.cache_for_epochs > 0 and FLAGS.feature_cache is not None:
feature_cache_index = FLAGS.feature_cache + '.index'
if epoch % FLAGS.cache_for_epochs == 0 and os.path.isfile(feature_cache_index):
log_info('Invalidating feature cache')
os.remove(feature_cache_index) # this will let TF also overwrite the related cache data files
# Setup progress bar
class LossWidget(progressbar.widgets.FormatLabel):
def __init__(self):
@ -567,11 +559,6 @@ def train():
session.run([train_op, global_step, loss, non_finite_files, step_summaries_op],
feed_dict=feed_dict)
exception_box.raise_if_set()
except tf.errors.InvalidArgumentError as err:
if FLAGS.augmentation_sparse_warp:
log_info("Ignoring sparse warp error: {}".format(err))
continue
raise
except tf.errors.OutOfRangeError:
exception_box.raise_if_set()
break
@ -680,7 +667,7 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
# Create feature computation graph
input_samples = tfv1.placeholder(tf.float32, [Config.audio_window_samples], 'input_samples')
samples = tf.expand_dims(input_samples, -1)
mfccs, _ = samples_to_mfccs(samples, FLAGS.audio_sample_rate)
mfccs, _ = audio_to_features(samples, FLAGS.audio_sample_rate)
mfccs = tf.identity(mfccs, name='mfccs')
# Input tensor will be of shape [batch_size, n_steps, 2*n_context+1, n_input]

Просмотреть файл

@ -0,0 +1,564 @@
import os
import re
import math
import random
import numpy as np
from multiprocessing import Queue, Process
from .audio import gain_db_to_ratio, max_dbfs, normalize_audio, AUDIO_TYPE_NP, AUDIO_TYPE_PCM, AUDIO_TYPE_OPUS
from .helpers import LimitingPool, int_range, float_range, pick_value_from_range, tf_pick_value_from_range, MEGABYTE
BUFFER_SIZE = 1 * MEGABYTE
SPEC_PARSER = re.compile(r'^(?P<cls>[a-z_]+)(\[(?P<params>.*)\])?$')
class Augmentation:
def __init__(self, p=1.0):
self.probability = float(p)
class SampleAugmentation(Augmentation):
def start(self, buffering=BUFFER_SIZE):
pass
def apply(self, sample, clock=0.0):
raise NotImplementedError
def stop(self):
pass
class GraphAugmentation(Augmentation):
def __init__(self, p=1.0, domain='spectrogram'):
super(GraphAugmentation, self).__init__(p)
if domain not in ['signal', 'spectrogram', 'features']:
raise ValueError('Unsupported augmentation domain: {}'.format(domain))
self.domain = domain
def apply(self, tensor, clock=0.0):
raise NotImplementedError
def apply_with_probability(self, tensor, clock=0.0):
import tensorflow as tf # pylint: disable=import-outside-toplevel
rv = tf.random.stateless_uniform([], seed=(clock * tf.int32.min, clock * tf.int32.max))
return tf.cond(tf.less(rv, self.probability),
lambda: self.apply(tensor, clock=clock),
lambda: tensor)
def maybe_apply(self, domain, tensor, clock=0.0):
if domain == self.domain:
return self.apply_with_probability(tensor, clock=clock)
return tensor
def parse_augmentation(augmentation_spec):
"""
Parses an augmentation specification.
Parameters
----------
augmentation_spec : str
Augmentation specification like "reverb[delay=20.0,decay=1.0]".
Returns
-------
Instance of an augmentation class from util.augmentations.*.
"""
match = SPEC_PARSER.match(augmentation_spec)
if not match:
raise ValueError('Augmentation specification has wrong format')
cls_name = ''.join(map(lambda p: p[0].upper() + p[1:], match.group('cls').split('_')))
augmentation_cls = globals()[cls_name] if cls_name in globals() else None
if augmentation_cls is None or not issubclass(augmentation_cls, Augmentation) or augmentation_cls == Augmentation:
raise ValueError('Unknown augmentation: {}'.format(cls_name))
parameters = match.group('params')
parameters = [] if parameters is None else parameters.split(',')
args = []
kwargs = {}
for parameter in parameters:
pair = tuple(list(map(str.strip, (parameter.split('=')))))
if len(pair) == 1:
args.append(pair)
elif len(pair) == 2:
kwargs[pair[0]] = pair[1]
else:
raise ValueError('Unable to parse augmentation value assignment')
return augmentation_cls(*args, **kwargs)
def parse_augmentations(augmentation_specs):
"""
Parses an augmentation specification.
Parameters
----------
augmentation_specs : list of str
List of augmentation specifications like ["reverb[delay=20.0,decay=1.0]", "volume"].
Returns
-------
List of augmentation class instances from util.augmentations.*.
"""
return [] if augmentation_specs is None else list(map(parse_augmentation, augmentation_specs))
def apply_graph_augmentations(domain, tensor, augmentations, clock=0.0):
"""
Augments training sample tensor of a certain domain with matching augmentations of passed list.
Parameters
----------
domain : str
Domain of the tensor to apply augmentations to. One of "signal", "spectrogram" or "features"
tensor : Tensor of type float32
Tensor to apply augmentations to.
augmentations : list of augmentation class instances from util.augmentations.*.
List of augmentations of which only the spectrogram ones will get applied to the samples.
clock : Tensor of type float32
Time indicator for augmentation value-ranges. Running from 0.0 (start of training) to 1.0 (end of training).
Returns
-------
Tensor of type float32
The augmented spectrogram
"""
if augmentations is not None:
# Warp has to come before any spectrogram masking
for augmentation in augmentations:
if isinstance(augmentation, Warp):
tensor = augmentation.maybe_apply(domain, tensor, clock=clock)
for augmentation in augmentations:
if isinstance(augmentation, GraphAugmentation) and not isinstance(augmentation, Warp):
tensor = augmentation.maybe_apply(domain, tensor, clock=clock)
return tensor
class AugmentationContext:
def __init__(self, target_audio_type, augmentations):
self.target_audio_type = target_audio_type
self.augmentations = augmentations
AUGMENTATION_CONTEXT = None
def _init_augmentation_worker(preparation_context):
global AUGMENTATION_CONTEXT # pylint: disable=global-statement
AUGMENTATION_CONTEXT = preparation_context
def _augment_sample(timed_sample, context=None):
context = AUGMENTATION_CONTEXT if context is None else context
sample, clock = timed_sample
for augmentation in context.augmentations:
if random.random() < augmentation.probability:
augmentation.apply(sample, clock)
sample.change_audio_type(new_audio_type=context.target_audio_type)
return sample
def apply_sample_augmentations(samples,
augmentations,
audio_type=AUDIO_TYPE_NP,
buffering=BUFFER_SIZE,
process_ahead=None,
clock=0.0,
final_clock=None):
"""
Prepares samples for being used during training.
This includes parallel and buffered application of augmentations and a conversion to a specified audio-type.
Parameters
----------
samples : Sample enumeration
Typically produced by util.sample_collections.samples_from_sources.
augmentations : list of augmentation class instances from util.augmentations.*.
List of augmentations of which only the signal ones will get applied to the samples.
audio_type : str
Target audio-type to convert samples to. See util.audio.Sample.__init__ .
buffering : int
Read-buffer size to use while reading files.
process_ahead : int
Number of samples to pre-process ahead of time.
clock : float
Start or fixed clock value between 0.0 and 1.0 for the first or all samples. Has to be <= than clock_to.
final_clock : float
Final clock value between 0.0 and 1.0 for the last sample. Has to be >= than clock.
Requires samples.__len__ attribute.
Returns
-------
iterable of util.sample_collections.LabeledSample or util.audio.Sample
"""
def timed_samples():
if final_clock is None:
for sample in samples:
yield sample, clock
else:
for sample_index, sample in enumerate(samples):
sample_clock = clock + (final_clock - clock) * (sample_index / len(samples))
yield sample, sample_clock
assert 0.0 <= clock <= 1.0
if final_clock is not None:
assert 0.0 <= final_clock <= 1.0
assert clock <= final_clock
augmentations = list(filter(lambda aug: isinstance(aug, SampleAugmentation), augmentations))
try:
for augmentation in augmentations:
augmentation.start(buffering=buffering)
context = AugmentationContext(audio_type, augmentations)
if process_ahead == 0:
for timed_sample in timed_samples():
yield _augment_sample(timed_sample, context=context)
else:
with LimitingPool(process_ahead=process_ahead,
initializer=_init_augmentation_worker,
initargs=(context,)) as pool:
yield from pool.imap(_augment_sample, timed_samples())
finally:
for augmentation in augmentations:
augmentation.stop()
def _enqueue_overlay_samples(sample_source, queue, buffering=BUFFER_SIZE):
"""
As the central distribution point for overlay samples this function is supposed to run in one process only.
This ensures that samples are not used twice if not required.
It loads the (raw and still compressed) data and provides it to the actual augmentation workers.
These are then doing decompression, potential conversion and overlaying in parallel.
"""
# preventing cyclic import problems
from .sample_collections import samples_from_source # pylint: disable=import-outside-toplevel
samples = samples_from_source(sample_source, buffering=buffering, labeled=False)
while True:
for sample in samples:
queue.put(sample)
class Overlay(SampleAugmentation):
"""See "Overlay augmentation" in TRAINING.rst"""
def __init__(self, source, p=1.0, snr=3.0, layers=1):
super(Overlay, self).__init__(p)
self.source = source
self.snr = float_range(snr)
self.layers = int_range(layers)
self.queue = Queue(max(1, math.floor(self.probability * self.layers[1] * os.cpu_count())))
self.current_sample = None
self.enqueue_process = None
def start(self, buffering=BUFFER_SIZE):
self.enqueue_process = Process(target=_enqueue_overlay_samples,
args=(self.source, self.queue),
kwargs={'buffering': buffering})
self.enqueue_process.start()
def apply(self, sample, clock=0.0):
sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP)
n_layers = pick_value_from_range(self.layers, clock=clock)
audio = sample.audio
overlay_data = np.zeros_like(audio)
for _ in range(n_layers):
overlay_offset = 0
while overlay_offset < len(audio):
if self.current_sample is None:
next_overlay_sample = self.queue.get()
next_overlay_sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP)
self.current_sample = next_overlay_sample.audio
n_required = len(audio) - overlay_offset
n_current = len(self.current_sample)
if n_required >= n_current: # take it completely
overlay_data[overlay_offset:overlay_offset + n_current] += self.current_sample
overlay_offset += n_current
self.current_sample = None
else: # take required slice from head and keep tail for next layer or sample
overlay_data[overlay_offset:overlay_offset + n_required] += self.current_sample[0:n_required]
overlay_offset += n_required
self.current_sample = self.current_sample[n_required:]
snr_db = pick_value_from_range(self.snr, clock=clock)
orig_dbfs = max_dbfs(audio)
overlay_gain = orig_dbfs - max_dbfs(overlay_data) - snr_db
audio += overlay_data * gain_db_to_ratio(overlay_gain)
sample.audio = normalize_audio(audio, dbfs=orig_dbfs)
def stop(self):
if self.enqueue_process is not None:
self.enqueue_process.terminate()
class Codec(SampleAugmentation):
"""See "Codec augmentation" in TRAINING.rst"""
def __init__(self, p=1.0, bitrate=3200):
super(Codec, self).__init__(p)
self.bitrate = int_range(bitrate)
def apply(self, sample, clock=0.0):
bitrate = pick_value_from_range(self.bitrate, clock=clock)
sample.change_audio_type(new_audio_type=AUDIO_TYPE_PCM) # decoding to ensure it has to get encoded again
sample.change_audio_type(new_audio_type=AUDIO_TYPE_OPUS, bitrate=bitrate) # will get decoded again downstream
class Reverb(SampleAugmentation):
"""See "Reverb augmentation" in TRAINING.rst"""
def __init__(self, p=1.0, delay=20.0, decay=10.0):
super(Reverb, self).__init__(p)
self.delay = float_range(delay)
self.decay = float_range(decay)
def apply(self, sample, clock=0.0):
sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP)
audio = np.array(sample.audio, dtype=np.float64)
orig_dbfs = max_dbfs(audio)
delay = pick_value_from_range(self.delay, clock=clock)
decay = pick_value_from_range(self.decay, clock=clock)
decay = gain_db_to_ratio(-decay)
result = np.copy(audio)
primes = [17, 19, 23, 29, 31]
for delay_prime in primes: # primes to minimize comb filter interference
layer = np.copy(audio)
n_delay = math.floor(delay * (delay_prime / primes[0]) * sample.audio_format.rate / 1000.0)
n_delay = max(16, n_delay) # 16 samples minimum to avoid performance trap and risk of division by zero
for w_index in range(0, math.floor(len(audio) / n_delay)):
w1 = w_index * n_delay
w2 = (w_index + 1) * n_delay
width = min(len(audio) - w2, n_delay) # last window could be smaller
layer[w2:w2 + width] += decay * layer[w1:w1 + width]
result += layer
audio = normalize_audio(result, dbfs=orig_dbfs)
sample.audio = np.array(audio, dtype=np.float32)
class Resample(SampleAugmentation):
"""See "Resample augmentation" in TRAINING.rst"""
def __init__(self, p=1.0, rate=8000):
super(Resample, self).__init__(p)
self.rate = int_range(rate)
def apply(self, sample, clock=0.0):
# late binding librosa and its dependencies
from librosa.core import resample # pylint: disable=import-outside-toplevel
sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP)
rate = pick_value_from_range(self.rate, clock=clock)
audio = sample.audio
orig_len = len(audio)
audio = np.swapaxes(audio, 0, 1)
audio = resample(audio, sample.audio_format.rate, rate)
audio = resample(audio, rate, sample.audio_format.rate)
audio = np.swapaxes(audio, 0, 1)[0:orig_len]
sample.audio = audio
class Volume(SampleAugmentation):
"""See "Volume augmentation" in TRAINING.rst"""
def __init__(self, p=1.0, dbfs=3.0103):
super(Volume, self).__init__(p)
self.target_dbfs = float_range(dbfs)
def apply(self, sample, clock=0.0):
sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP)
target_dbfs = pick_value_from_range(self.target_dbfs, clock=clock)
sample.audio = normalize_audio(sample.audio, dbfs=target_dbfs)
class PitchAndTempo(GraphAugmentation):
"""See "Pitch and tempo augmentation" in TRAINING.rst"""
def __init__(self, p=1.0, tempo=1.2, pitch=(1.075, 1.075, 0.125)):
super(PitchAndTempo, self).__init__(p, domain='spectrogram')
self.tempo = float_range(tempo)
self.pitch = float_range(pitch)
def apply(self, tensor, clock=0.0):
import tensorflow as tf # pylint: disable=import-outside-toplevel
original_shape = tf.shape(tensor)
pitch = tf_pick_value_from_range(self.pitch, clock=clock)
tempo = tf.math.maximum(1.0, tf_pick_value_from_range(self.tempo, clock=clock))
new_freq_size = tf.cast(tf.cast(original_shape[2], tf.float32) * pitch, tf.int32)
new_time_size = tf.cast(tf.cast(original_shape[1], tf.float32) / tempo, tf.int32)
spectrogram_aug = tf.image.resize_bilinear(tf.expand_dims(tensor, -1), [new_time_size, new_freq_size])
spectrogram_aug = tf.image.crop_to_bounding_box(spectrogram_aug,
offset_height=0,
offset_width=0,
target_height=tf.shape(spectrogram_aug)[1],
target_width=tf.math.minimum(original_shape[2], new_freq_size))
spectrogram_aug = tf.cond(pitch < 1,
lambda: tf.image.pad_to_bounding_box(spectrogram_aug,
offset_height=0,
offset_width=0,
target_height=tf.shape(spectrogram_aug)[1],
target_width=original_shape[2]),
lambda: spectrogram_aug)
return spectrogram_aug[:, :, :, 0]
class Speed(GraphAugmentation):
"""See "Speed augmentation" in TRAINING.rst"""
def __init__(self, p=1.0, factor=1.1):
super(Speed, self).__init__(p, domain='spectrogram')
self.factor = float_range(factor)
def apply(self, tensor, clock=0.0):
import tensorflow as tf # pylint: disable=import-outside-toplevel
factor = tf_pick_value_from_range(self.factor, clock=clock)
original_shape = tf.shape(tensor)
new_time_size = tf.cast(tf.cast(original_shape[1], tf.float32) / factor, tf.int32)
spectrogram_aug = tf.image.resize_bilinear(tf.expand_dims(tensor, -1), [new_time_size, original_shape[2]])
return spectrogram_aug[:, :, :, 0]
class Warp(GraphAugmentation):
"""See "Warp augmentation" in TRAINING.rst"""
def __init__(self, p=1.0, shift=100.0, order=3, nbp=1, ncp=1, regularization_weight=0.0):
super(Warp, self).__init__(p, domain='spectrogram')
self.shift = float_range(shift)
self.order = int_range(order)
self.nbp = int_range(nbp)
self.ncp = int_range(ncp)
# Making this a value-range is impossible, as it would get a tensor which would downstream be used as parameter
# of a comparison inside tensorflow.contrib.image.python.ops.interpolate_spline. This is not supported.
self.regularization_weight = float(regularization_weight)
def apply(self, tensor, clock=0.0):
import tensorflow as tf # pylint: disable=import-outside-toplevel
from .flags import FLAGS # pylint: disable=import-outside-toplevel
from .sparse_image_warp import sparse_image_warp # pylint: disable=import-outside-toplevel
# reshape to fit `sparse_image_warp`'s input shape (1, time steps, freq, 1), batch_size must be 1
expanded_spectrogram = tf.expand_dims(tensor, -1)
original_shape = tf.shape(expanded_spectrogram)
tau, freq_size = original_shape[1], original_shape[2]
seed = (clock * tf.int32.min, clock * tf.int32.max)
shift = tf_pick_value_from_range(self.shift, clock=clock)
shift *= FLAGS.audio_sample_rate / (FLAGS.feature_win_step * 1000.0) # number of windows
shift = tf.math.minimum(tf.cast(shift, dtype=tf.int32), tf.math.floordiv(tau, 2) - 1) # to protect short audio
nbp = tf_pick_value_from_range(self.nbp, clock=clock)
ncp = tf_pick_value_from_range(self.ncp, clock=clock)
# workaround for missing stateless shuffle support
frequencies = tf.random.stateless_uniform([2 * ncp], seed, minval=1, maxval=freq_size - 2, dtype=tf.int32)
frequencies = tf.unique(tf.concat([frequencies, tf.range(1, limit=freq_size - 3)], axis=0))[0][0:ncp]
source_max = tau - shift
source_min = tf.math.minimum(source_max - ncp, shift)
# workaround for missing stateless shuffle support
src_times = tf.random.stateless_uniform([2 * ncp], seed, minval=source_min, maxval=source_max, dtype=tf.int32)
src_times = tf.unique(tf.concat([src_times, tf.range(1, limit=source_max)], axis=0))[0][0:ncp]
dst_times = src_times + tf.random.stateless_uniform([ncp], seed, minval=-shift, maxval=shift, dtype=tf.int32)
scp_locations = tf.cast([tf.transpose(tf.stack([src_times, frequencies]))], dtype=tf.float32)
dcp_locations = tf.cast([tf.transpose(tf.stack([dst_times, frequencies]))], dtype=tf.float32)
order = tf_pick_value_from_range(self.order, clock=clock)
order = tf.math.maximum(3, order) # prevents "Input matrix is not invertible." exception
order = tf.cast(order, tf.float32)
spectrogram_aug, _ = sparse_image_warp(expanded_spectrogram,
source_control_point_locations=scp_locations,
dest_control_point_locations=dcp_locations,
interpolation_order=order,
regularization_weight=self.regularization_weight,
num_boundary_points=nbp)
return tf.reshape(spectrogram_aug, shape=(1, -1, freq_size))
class FrequencyMask(GraphAugmentation):
"""See "Frequency mask augmentation" in TRAINING.rst"""
def __init__(self, p=1.0, n=3, size=2):
super(FrequencyMask, self).__init__(p, domain='spectrogram')
self.n = int_range(n) # pylint: disable=invalid-name
self.size = int_range(size)
def apply(self, tensor, clock=0.0):
import tensorflow as tf # pylint: disable=import-outside-toplevel
time_max = tf.shape(tensor)[1]
freq_max = tf.shape(tensor)[2]
n = tf_pick_value_from_range(self.n, clock=clock)
def body(i, spectrogram_aug):
size = tf_pick_value_from_range(self.size, clock=clock)
size = tf.math.maximum(1, tf.math.minimum(freq_max - 1, size))
seed = tf.cast(clock * tf.int32.max, tf.int32) - i
f0 = tf.random.stateless_uniform((), (-seed, seed), minval=0, maxval=freq_max - size, dtype=tf.dtypes.int32)
freq_mask = tf.concat([tf.ones([1, time_max, f0]),
tf.zeros([1, time_max, size]),
tf.ones([1, time_max, freq_max - f0 - size])], axis=2)
return i + 1, spectrogram_aug * freq_mask
return tf.while_loop(lambda i, spectrogram_aug: i < n, body, (0, tensor))[1]
class TimeMask(GraphAugmentation):
"""See "Time mask augmentation" in TRAINING.rst"""
def __init__(self, p=1.0, domain='spectrogram', n=3, size=10.0):
super(TimeMask, self).__init__(p, domain=domain)
self.n = int_range(n) # pylint: disable=invalid-name
self.size = float_range(size)
def apply(self, tensor, clock=0.0):
import tensorflow as tf # pylint: disable=import-outside-toplevel
from .flags import FLAGS # pylint: disable=import-outside-toplevel
time_factor = FLAGS.audio_sample_rate / 1000.0 # samples per ms
if self.domain != 'signal':
time_factor /= FLAGS.feature_win_step # windows per ms
time_max = tf.shape(tensor)[0] if self.domain == 'signal' else tf.shape(tensor)[1]
n = tf_pick_value_from_range(self.n, clock=clock)
def body(i, augmented):
size = tf.cast(tf_pick_value_from_range(self.size, clock=clock) * time_factor, dtype=tf.int32)
size = tf.math.maximum(1, tf.math.minimum(time_max - 1, size))
tf.print(size)
seed = tf.cast(clock * tf.int32.max, tf.int32) - i
t0 = tf.random.stateless_uniform((), (-seed, seed), minval=0, maxval=time_max - size, dtype=tf.dtypes.int32)
rest = time_max - t0 - size
if self.domain == 'spectrogram':
fm = tf.shape(tensor)[2]
time_mask = tf.concat([tf.ones([1, t0, fm]), tf.zeros([1, size, fm]), tf.ones([1, rest, fm])], axis=1)
elif self.domain == 'signal':
time_mask = tf.concat([tf.ones([t0, 1]), tf.zeros([size, 1]), tf.ones([rest, 1])], axis=0)
else:
time_mask = tf.concat([tf.ones([1, t0]), tf.zeros([1, size]), tf.ones([1, rest])], axis=1)
return i + 1, augmented * time_mask
return tf.while_loop(lambda i, augmented: i < n, body, (0, tensor))[1]
class Dropout(GraphAugmentation):
"""See "Dropout augmentation" in TRAINING.rst"""
def __init__(self, p=1.0, domain='spectrogram', rate=0.05):
super(Dropout, self).__init__(p, domain=domain)
self.rate = float_range(rate)
def apply(self, tensor, clock=0.0):
import tensorflow as tf # pylint: disable=import-outside-toplevel
rate = tf_pick_value_from_range(self.rate, clock=clock)
rate = tf.math.maximum(0.0, rate)
factors = tf.random.stateless_uniform(tf.shape(tensor),
(clock * tf.int32.min, clock * tf.int32.max),
minval=0.0,
maxval=1.0,
dtype=tf.float32)
return tensor * tf.math.sign(tf.math.floor(factors + rate))
class Add(GraphAugmentation):
"""See "Add augmentation" in TRAINING.rst"""
def __init__(self, p=1.0, domain='features', stddev=5):
super(Add, self).__init__(p, domain=domain)
self.stddev = float_range(stddev)
def apply(self, tensor, clock=0.0):
import tensorflow as tf # pylint: disable=import-outside-toplevel
stddev = tf_pick_value_from_range(self.stddev, clock=clock)
seed = (clock * tf.int32.min, clock * tf.int32.max)
return tensor + tf.random.stateless_normal(tf.shape(tensor), seed, mean=0.0, stddev=stddev)
class Multiply(GraphAugmentation):
"""See "Multiply augmentation" in TRAINING.rst"""
def __init__(self, p=1.0, domain='features', stddev=5):
super(Multiply, self).__init__(p, domain=domain)
self.stddev = float_range(stddev)
def apply(self, tensor, clock=0.0):
import tensorflow as tf # pylint: disable=import-outside-toplevel
stddev = tf_pick_value_from_range(self.stddev, clock=clock)
seed = (clock * tf.int32.min, clock * tf.int32.max)
return tensor * tf.random.stateless_normal(tf.shape(tensor), seed, mean=1.0, stddev=stddev)

Просмотреть файл

@ -2,7 +2,6 @@ from __future__ import absolute_import, division, print_function
import os
import sys
import tensorflow as tf
import tensorflow.compat.v1 as tfv1
from attrdict import AttrDict
@ -13,6 +12,7 @@ from .gpu import get_available_gpus
from .logging import log_error, log_warn
from .text import Alphabet, UTF8Alphabet
from .helpers import parse_file_size
from .augmentations import parse_augmentations
class ConfigSingleton:
_config = None
@ -30,6 +30,17 @@ Config = ConfigSingleton() # pylint: disable=invalid-name
def initialize_globals():
c = AttrDict()
# Augmentations
c.augmentations = parse_augmentations(FLAGS.augment)
if len(c.augmentations) > 0 and FLAGS.feature_cache is not None and FLAGS.cache_for_epochs == 0:
log_warn('Due to current feature-cache settings the exact same sample augmentations of the first '
'epoch will be repeated on all following epochs. This could lead to unintended over-fitting. '
'You could use --cache_for_epochs <n_epochs> to invalidate the cache after a given number of epochs.')
# Caching
if FLAGS.cache_for_epochs == 1:
log_warn('--cache_for_epochs == 1 is (re-)creating the feature cache on every epoch but will never use it.')
# Read-buffer
FLAGS.read_buffer = parse_file_size(FLAGS.read_buffer)

Просмотреть файл

@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function
from collections import Counter
from functools import partial
import numpy as np
@ -11,13 +12,13 @@ from tensorflow.python.ops import gen_audio_ops as contrib_audio
from .config import Config
from .text import text_to_char_array
from .flags import FLAGS
from .spectrogram_augmentations import augment_freq_time_mask, augment_dropout, augment_pitch_and_tempo, augment_speed_up, augment_sparse_warp
from .augmentations import apply_sample_augmentations, apply_graph_augmentations
from .audio import read_frames_from_file, vad_split, pcm_to_np, DEFAULT_FORMAT
from .sample_collections import samples_from_sources, augment_samples
from .sample_collections import samples_from_sources
from .helpers import remember_exception, MEGABYTE
def samples_to_mfccs(samples, sample_rate, train_phase=False, sample_id=None):
def audio_to_features(audio, sample_rate, clock=0.0, train_phase=False, augmentations=None, sample_id=None):
if train_phase:
# We need the lambdas to make TensorFlow happy.
# pylint: disable=unnecessary-lambda
@ -27,73 +28,48 @@ def samples_to_mfccs(samples, sample_rate, train_phase=False, sample_id=None):
lambda: tf.no_op(),
name='matching_sample_rate')
spectrogram = contrib_audio.audio_spectrogram(samples,
if train_phase and augmentations is not None:
audio = apply_graph_augmentations('signal', audio, augmentations, clock=clock)
spectrogram = contrib_audio.audio_spectrogram(audio,
window_size=Config.audio_window_samples,
stride=Config.audio_step_samples,
magnitude_squared=True)
# Data Augmentations
if train_phase:
if FLAGS.augmentation_spec_dropout_keeprate < 1:
spectrogram = augment_dropout(spectrogram,
keep_prob=FLAGS.augmentation_spec_dropout_keeprate)
if train_phase and augmentations is not None:
spectrogram = apply_graph_augmentations('spectrogram', spectrogram, augmentations, clock=clock)
# sparse warp must before freq/time masking
if FLAGS.augmentation_sparse_warp:
spectrogram = augment_sparse_warp(spectrogram,
time_warping_para=FLAGS.augmentation_sparse_warp_time_warping_para,
interpolation_order=FLAGS.augmentation_sparse_warp_interpolation_order,
regularization_weight=FLAGS.augmentation_sparse_warp_regularization_weight,
num_boundary_points=FLAGS.augmentation_sparse_warp_num_boundary_points,
num_control_points=FLAGS.augmentation_sparse_warp_num_control_points)
features = contrib_audio.mfcc(spectrogram=spectrogram,
sample_rate=sample_rate,
dct_coefficient_count=Config.n_input,
upper_frequency_limit=FLAGS.audio_sample_rate / 2)
features = tf.reshape(features, [-1, Config.n_input])
if FLAGS.augmentation_freq_and_time_masking:
spectrogram = augment_freq_time_mask(spectrogram,
frequency_masking_para=FLAGS.augmentation_freq_and_time_masking_freq_mask_range,
time_masking_para=FLAGS.augmentation_freq_and_time_masking_time_mask_range,
frequency_mask_num=FLAGS.augmentation_freq_and_time_masking_number_freq_masks,
time_mask_num=FLAGS.augmentation_freq_and_time_masking_number_time_masks)
if train_phase and augmentations is not None:
features = apply_graph_augmentations('features', features, augmentations, clock=clock)
if FLAGS.augmentation_pitch_and_tempo_scaling:
spectrogram = augment_pitch_and_tempo(spectrogram,
max_tempo=FLAGS.augmentation_pitch_and_tempo_scaling_max_tempo,
max_pitch=FLAGS.augmentation_pitch_and_tempo_scaling_max_pitch,
min_pitch=FLAGS.augmentation_pitch_and_tempo_scaling_min_pitch)
if FLAGS.augmentation_speed_up_std > 0:
spectrogram = augment_speed_up(spectrogram, speed_std=FLAGS.augmentation_speed_up_std)
mfccs = contrib_audio.mfcc(spectrogram=spectrogram,
sample_rate=sample_rate,
dct_coefficient_count=Config.n_input,
upper_frequency_limit=FLAGS.audio_sample_rate/2)
mfccs = tf.reshape(mfccs, [-1, Config.n_input])
return mfccs, tf.shape(input=mfccs)[0]
return features, tf.shape(input=features)[0]
def audio_to_features(audio, sample_rate, train_phase=False, sample_id=None):
features, features_len = samples_to_mfccs(audio, sample_rate, train_phase=train_phase, sample_id=sample_id)
if train_phase:
if FLAGS.data_aug_features_multiplicative > 0:
features = features*tf.random.normal(mean=1, stddev=FLAGS.data_aug_features_multiplicative, shape=tf.shape(features))
if FLAGS.data_aug_features_additive > 0:
features = features+tf.random.normal(mean=0.0, stddev=FLAGS.data_aug_features_additive, shape=tf.shape(features))
return features, features_len
def audiofile_to_features(wav_filename, train_phase=False):
def audiofile_to_features(wav_filename, clock=0.0, train_phase=False, augmentations=None):
samples = tf.io.read_file(wav_filename)
decoded = contrib_audio.decode_wav(samples, desired_channels=1)
return audio_to_features(decoded.audio, decoded.sample_rate, train_phase=train_phase, sample_id=wav_filename)
return audio_to_features(decoded.audio,
decoded.sample_rate,
clock=clock,
train_phase=train_phase,
augmentations=augmentations,
sample_id=wav_filename)
def entry_to_features(sample_id, audio, sample_rate, transcript, train_phase=False):
def entry_to_features(sample_id, audio, sample_rate, transcript, clock, train_phase=False, augmentations=None):
# https://bugs.python.org/issue32117
features, features_len = audio_to_features(audio, sample_rate, train_phase=train_phase, sample_id=sample_id)
features, features_len = audio_to_features(audio,
sample_rate,
clock=clock,
train_phase=train_phase,
augmentations=augmentations,
sample_id=sample_id)
sparse_transcript = tf.SparseTensor(*transcript)
return sample_id, features, features_len, sparse_transcript
@ -109,25 +85,32 @@ def to_sparse_tuple(sequence):
def create_dataset(sources,
batch_size,
repetitions=1,
augmentation_specs=None,
enable_cache=False,
epochs=1,
augmentations=None,
cache_path=None,
train_phase=False,
exception_box=None,
process_ahead=None,
buffering=1 * MEGABYTE):
epoch_counter = Counter() # survives restarts of the dataset and its generator
def generate_values():
epoch = epoch_counter['epoch']
if train_phase:
epoch_counter['epoch'] += 1
samples = samples_from_sources(sources, buffering=buffering, labeled=True)
samples = augment_samples(samples,
repetitions=repetitions,
augmentation_specs=augmentation_specs,
buffering=buffering,
process_ahead=2 * batch_size if process_ahead is None else process_ahead)
for sample in samples:
num_samples = len(samples)
samples = apply_sample_augmentations(samples,
augmentations,
buffering=buffering,
process_ahead=2 * batch_size if process_ahead is None else process_ahead,
clock=epoch / epochs,
final_clock=(epoch + 1) / epochs)
for sample_index, sample in enumerate(samples):
clock = (epoch * num_samples + sample_index) / (epochs * num_samples) if train_phase and epochs > 0 else 0.0
transcript = text_to_char_array(sample.transcript, Config.alphabet, context=sample.sample_id)
transcript = to_sparse_tuple(transcript)
yield sample.sample_id, sample.audio, sample.audio_format.rate, transcript
yield sample.sample_id, sample.audio, sample.audio_format.rate, transcript, clock
# Batching a dataset of 2D SparseTensors creates 3D batches, which fail
# when passed to tf.nn.ctc_loss, so we reshape them to remove the extra
@ -143,13 +126,13 @@ def create_dataset(sources,
sample_ids = sample_ids.batch(batch_size)
return tf.data.Dataset.zip((sample_ids, features, transcripts))
process_fn = partial(entry_to_features, train_phase=train_phase)
process_fn = partial(entry_to_features, train_phase=train_phase, augmentations=augmentations)
dataset = (tf.data.Dataset.from_generator(remember_exception(generate_values, exception_box),
output_types=(tf.string, tf.float32, tf.int32,
(tf.int64, tf.int32, tf.int64)))
(tf.int64, tf.int32, tf.int64), tf.float64))
.map(process_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE))
if enable_cache:
if cache_path is not None:
dataset = dataset.cache(cache_path)
dataset = (dataset.window(batch_size, drop_remainder=train_phase).flat_map(batch_fn)
.prefetch(len(Config.available_devices)))
@ -172,7 +155,7 @@ def split_audio_file(audio_path,
yield time_start, time_end, samples
def to_mfccs(time_start, time_end, samples):
features, features_len = samples_to_mfccs(samples, audio_format.rate)
features, features_len = audio_to_features(samples, audio_format.rate)
return time_start, time_end, features, features_len
def create_batch_set(bs, criteria):

Просмотреть файл

@ -19,6 +19,7 @@ def create_flags():
f.DEFINE_string('read_buffer', '1MB', 'buffer-size for reading samples from datasets (supports file-size suffixes KB, MB, GB, TB)')
f.DEFINE_string('feature_cache', '', 'cache MFCC features to disk to speed up future training runs on the same data. This flag specifies the path where cached features extracted from --train_files will be saved. If empty, or if online augmentation flags are enabled, caching will be disabled.')
f.DEFINE_integer('cache_for_epochs', 0, 'after how many epochs the feature cache is invalidated again - 0 for "never"')
f.DEFINE_integer('feature_win_len', 32, 'feature extraction audio window length in milliseconds')
f.DEFINE_integer('feature_win_step', 20, 'feature extraction window step length in milliseconds')
@ -28,32 +29,6 @@ def create_flags():
# ================
f.DEFINE_multi_string('augment', None, 'specifies an augmentation of the training samples. Format is "--augment operation[param1=value1, ...]"')
f.DEFINE_integer('augmentations_per_epoch', 1, 'how often the train set should be repeated and re-augmented per epoch')
f.DEFINE_float('data_aug_features_additive', 0, 'std of the Gaussian additive noise')
f.DEFINE_float('data_aug_features_multiplicative', 0, 'std of normal distribution around 1 for multiplicative noise')
f.DEFINE_float('augmentation_spec_dropout_keeprate', 1, 'keep rate of dropout augmentation on spectrogram (if 1, no dropout will be performed on spectrogram)')
f.DEFINE_boolean('augmentation_sparse_warp', False, 'whether to use spectrogram sparse warp. USE OF THIS FLAG IS UNSUPPORTED, enable sparse warp will increase training time drastically, and the paper also mentioned that this is not a major factor to improve accuracy.')
f.DEFINE_integer('augmentation_sparse_warp_num_control_points', 1, 'specify number of control points')
f.DEFINE_integer('augmentation_sparse_warp_time_warping_para', 20, 'time_warping_para')
f.DEFINE_integer('augmentation_sparse_warp_interpolation_order', 2, 'sparse_warp_interpolation_order')
f.DEFINE_float('augmentation_sparse_warp_regularization_weight', 0.0, 'sparse_warp_regularization_weight')
f.DEFINE_integer('augmentation_sparse_warp_num_boundary_points', 1, 'sparse_warp_num_boundary_points')
f.DEFINE_boolean('augmentation_freq_and_time_masking', False, 'whether to use frequency and time masking augmentation')
f.DEFINE_integer('augmentation_freq_and_time_masking_freq_mask_range', 5, 'max range of masks in the frequency domain when performing freqtime-mask augmentation')
f.DEFINE_integer('augmentation_freq_and_time_masking_number_freq_masks', 3, 'number of masks in the frequency domain when performing freqtime-mask augmentation')
f.DEFINE_integer('augmentation_freq_and_time_masking_time_mask_range', 2, 'max range of masks in the time domain when performing freqtime-mask augmentation')
f.DEFINE_integer('augmentation_freq_and_time_masking_number_time_masks', 3, 'number of masks in the time domain when performing freqtime-mask augmentation')
f.DEFINE_float('augmentation_speed_up_std', 0, 'std for speeding-up tempo. If std is 0, this augmentation is not performed')
f.DEFINE_boolean('augmentation_pitch_and_tempo_scaling', False, 'whether to use spectrogram speed and tempo scaling')
f.DEFINE_float('augmentation_pitch_and_tempo_scaling_min_pitch', 0.95, 'min value of pitch scaling')
f.DEFINE_float('augmentation_pitch_and_tempo_scaling_max_pitch', 1.2, 'max value of pitch scaling')
f.DEFINE_float('augmentation_pitch_and_tempo_scaling_max_tempo', 1.2, 'max vlaue of tempo scaling')
# Global Constants
# ================

Просмотреть файл

@ -174,3 +174,18 @@ def pick_value_from_range(value_range, clock=None):
value = value_range.start + clock * (value_range.end - value_range.start)
value = random.uniform(value - value_range.r, value + value_range.r)
return round(value) if isinstance(value_range.start, int) else value
def tf_pick_value_from_range(value_range, clock=None, double_precision=False):
import tensorflow as tf # pylint: disable=import-outside-toplevel
clock = (tf.random.stateless_uniform([], seed=(-1, 1), dtype=tf.float64) if clock is None
else tf.maximum(tf.constant(0.0, dtype=tf.float64), tf.minimum(tf.constant(1.0, dtype=tf.float64), clock)))
value = value_range.start + clock * (value_range.end - value_range.start)
value = tf.random.stateless_uniform([],
minval=value - value_range.r,
maxval=value + value_range.r,
seed=(clock * tf.int32.min, clock * tf.int32.max),
dtype=tf.float64)
if isinstance(value_range.start, int):
return tf.cast(tf.math.round(value), tf.int64 if double_precision else tf.int32)
return tf.cast(value, tf.float64 if double_precision else tf.float32)

Просмотреть файл

@ -2,14 +2,12 @@
import os
import csv
import json
import random
from pathlib import Path
from functools import partial
from .signal_augmentations import parse_augmentation
from .helpers import MEGABYTE, GIGABYTE, Interleaved, LimitingPool
from .audio import Sample, DEFAULT_FORMAT, AUDIO_TYPE_OPUS, AUDIO_TYPE_NP, SERIALIZABLE_AUDIO_TYPES, get_audio_type_from_extension
from .helpers import MEGABYTE, GIGABYTE, Interleaved
from .audio import Sample, DEFAULT_FORMAT, AUDIO_TYPE_OPUS, SERIALIZABLE_AUDIO_TYPES, get_audio_type_from_extension
BIG_ENDIAN = 'big'
INT_SIZE = 4
@ -416,88 +414,3 @@ def samples_from_sources(sample_sources, buffering=BUFFER_SIZE, labeled=None):
return samples_from_source(sample_sources[0], buffering=buffering, labeled=labeled)
cols = list(map(partial(samples_from_source, buffering=buffering, labeled=labeled), sample_sources))
return Interleaved(*cols, key=lambda s: s.duration)
class PreparationContext:
def __init__(self, target_audio_type, augmentations):
self.target_audio_type = target_audio_type
self.augmentations = augmentations
AUGMENTATION_CONTEXT = None
def _init_augmentation_worker(preparation_context):
global AUGMENTATION_CONTEXT # pylint: disable=global-statement
AUGMENTATION_CONTEXT = preparation_context
def _augment_sample(timed_sample, context=None):
context = AUGMENTATION_CONTEXT if context is None else context
sample, clock = timed_sample
for augmentation in context.augmentations:
if random.random() < augmentation.probability:
augmentation.apply(sample, clock)
sample.change_audio_type(new_audio_type=context.target_audio_type)
return sample
def augment_samples(samples,
audio_type=AUDIO_TYPE_NP,
augmentation_specs=None,
buffering=BUFFER_SIZE,
process_ahead=None,
repetitions=1,
fixed_clock=None):
"""
Prepares samples for being used during training.
This includes parallel and buffered application of augmentations and a conversion to a specified audio-type.
Parameters
----------
samples : Sample enumeration
Typically produced by samples_from_sources.
audio_type : str
Target audio-type to convert samples to. See util.audio.Sample.__init__ .
augmentation_specs : list of str
Augmentation specifications like ["reverb[delay=20.0,decay=-20]", "volume"]. See TRAINING.rst.
buffering : int
Read-buffer size to use while reading files.
process_ahead : int
Number of samples to pre-process ahead of time.
repetitions : int
How often the input sample enumeration should get repeated for being re-augmented.
fixed_clock : float
Sets the internal clock to a value between 0.0 (beginning of epoch) and 1.0 (end of epoch).
Setting this to a number is used for simulating augmentations at a certain epoch-time.
If kept at None (default), the internal clock will run regularly from 0.0 to 1.0,
hence preparing them for training.
Returns
-------
iterable of util.sample_collections.LabeledSample or util.audio.Sample
"""
def timed_samples():
for repetition in range(repetitions):
for sample_index, sample in enumerate(samples):
if fixed_clock is None:
yield sample, (repetition * len(samples) + sample_index) / (repetitions * len(samples))
else:
yield sample, fixed_clock
augmentations = [] if augmentation_specs is None else list(map(parse_augmentation, augmentation_specs))
try:
for augmentation in augmentations:
augmentation.start(buffering=buffering)
context = PreparationContext(audio_type, augmentations)
if process_ahead == 0:
for timed_sample in timed_samples():
yield _augment_sample(timed_sample, context=context)
else:
with LimitingPool(process_ahead=process_ahead,
initializer=_init_augmentation_worker,
initargs=(context,)) as pool:
yield from pool.imap(_augment_sample, timed_samples())
finally:
for augmentation in augmentations:
augmentation.stop()

Просмотреть файл

@ -1,222 +0,0 @@
import os
import re
import math
import random
import numpy as np
from multiprocessing import Queue, Process
from .audio import gain_db_to_ratio, max_dbfs, normalize_audio, AUDIO_TYPE_NP, AUDIO_TYPE_PCM, AUDIO_TYPE_OPUS
from .helpers import int_range, float_range, pick_value_from_range, MEGABYTE
SPEC_PARSER = re.compile(r'^(?P<cls>[a-z]+)(\[(?P<params>.*)\])?$')
BUFFER_SIZE = 1 * MEGABYTE
class Augmentation:
def __init__(self, p=1.0):
self.probability = float(p)
def start(self, buffering=BUFFER_SIZE):
pass
def apply(self, sample, clock):
raise NotImplementedError
def stop(self):
pass
def _enqueue_overlay_samples(sample_source, queue, buffering=BUFFER_SIZE):
"""
As the central distribution point for overlay samples this function is supposed to run in one process only.
This ensures that samples are not used twice if not required.
It loads the (raw and still compressed) data and provides it to the actual augmentation workers.
These are then doing decompression, potential conversion and overlaying in parallel.
"""
# preventing cyclic import problems
from .sample_collections import samples_from_source # pylint: disable=import-outside-toplevel
samples = samples_from_source(sample_source, buffering=buffering, labeled=False)
while True:
for sample in samples:
queue.put(sample)
class Overlay(Augmentation):
"""See "Overlay augmentation" in TRAINING.rst"""
def __init__(self, source, p=1.0, snr=3.0, layers=1):
super(Overlay, self).__init__(p)
self.source = source
self.snr = float_range(snr)
self.layers = int_range(layers)
self.queue = Queue(max(1, math.floor(self.probability * self.layers[1] * os.cpu_count())))
self.current_sample = None
self.enqueue_process = None
def start(self, buffering=BUFFER_SIZE):
self.enqueue_process = Process(target=_enqueue_overlay_samples,
args=(self.source, self.queue),
kwargs={'buffering': buffering})
self.enqueue_process.start()
def apply(self, sample, clock):
sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP)
n_layers = pick_value_from_range(self.layers, clock=clock)
audio = sample.audio
overlay_data = np.zeros_like(audio)
for _ in range(n_layers):
overlay_offset = 0
while overlay_offset < len(audio):
if self.current_sample is None:
next_overlay_sample = self.queue.get()
next_overlay_sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP)
self.current_sample = next_overlay_sample.audio
n_required = len(audio) - overlay_offset
n_current = len(self.current_sample)
if n_required >= n_current: # take it completely
overlay_data[overlay_offset:overlay_offset + n_current] += self.current_sample
overlay_offset += n_current
self.current_sample = None
else: # take required slice from head and keep tail for next layer or sample
overlay_data[overlay_offset:overlay_offset + n_required] += self.current_sample[0:n_required]
overlay_offset += n_required
self.current_sample = self.current_sample[n_required:]
snr_db = pick_value_from_range(self.snr, clock=clock)
orig_dbfs = max_dbfs(audio)
overlay_gain = orig_dbfs - max_dbfs(overlay_data) - snr_db
audio += overlay_data * gain_db_to_ratio(overlay_gain)
sample.audio = normalize_audio(audio, dbfs=orig_dbfs)
def stop(self):
if self.enqueue_process is not None:
self.enqueue_process.terminate()
class Reverb(Augmentation):
"""See "Reverb augmentation" in TRAINING.rst"""
def __init__(self, p=1.0, delay=20.0, decay=10.0):
super(Reverb, self).__init__(p)
self.delay = float_range(delay)
self.decay = float_range(decay)
def apply(self, sample, clock):
sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP)
audio = np.array(sample.audio, dtype=np.float64)
orig_dbfs = max_dbfs(audio)
delay = pick_value_from_range(self.delay, clock=clock)
decay = pick_value_from_range(self.decay, clock=clock)
decay = gain_db_to_ratio(-decay)
result = np.copy(audio)
primes = [17, 19, 23, 29, 31]
for delay_prime in primes: # primes to minimize comb filter interference
layer = np.copy(audio)
n_delay = math.floor(delay * (delay_prime / primes[0]) * sample.audio_format.rate / 1000.0)
n_delay = max(16, n_delay) # 16 samples minimum to avoid performance trap and risk of division by zero
for w_index in range(0, math.floor(len(audio) / n_delay)):
w1 = w_index * n_delay
w2 = (w_index + 1) * n_delay
width = min(len(audio) - w2, n_delay) # last window could be smaller
layer[w2:w2 + width] += decay * layer[w1:w1 + width]
result += layer
audio = normalize_audio(result, dbfs=orig_dbfs)
sample.audio = np.array(audio, dtype=np.float32)
class Resample(Augmentation):
"""See "Resample augmentation" in TRAINING.rst"""
def __init__(self, p=1.0, rate=8000):
super(Resample, self).__init__(p)
self.rate = int_range(rate)
def apply(self, sample, clock):
# late binding librosa and its dependencies
from librosa.core import resample # pylint: disable=import-outside-toplevel
sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP)
rate = pick_value_from_range(self.rate, clock=clock)
audio = sample.audio
orig_len = len(audio)
audio = np.swapaxes(audio, 0, 1)
audio = resample(audio, sample.audio_format.rate, rate)
audio = resample(audio, rate, sample.audio_format.rate)
audio = np.swapaxes(audio, 0, 1)[0:orig_len]
sample.audio = audio
class Codec(Augmentation):
"""See "Codec augmentation" in TRAINING.rst"""
def __init__(self, p=1.0, bitrate=3200):
super(Codec, self).__init__(p)
self.bitrate = int_range(bitrate)
def apply(self, sample, clock):
bitrate = pick_value_from_range(self.bitrate, clock=clock)
sample.change_audio_type(new_audio_type=AUDIO_TYPE_PCM) # decoding to ensure it has to get encoded again
sample.change_audio_type(new_audio_type=AUDIO_TYPE_OPUS, bitrate=bitrate) # will get decoded again downstream
class Gaps(Augmentation):
"""See "Gaps augmentation" in TRAINING.rst"""
def __init__(self, p=1.0, n=1, size=50.0):
super(Gaps, self).__init__(p)
self.n_gaps = int_range(n)
self.size = float_range(size)
def apply(self, sample, clock):
sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP)
audio = sample.audio
n_gaps = pick_value_from_range(self.n_gaps, clock=clock)
for _ in range(n_gaps):
size = pick_value_from_range(self.size, clock=clock)
size = int(size * sample.audio_format.rate / 1000.0)
size = min(size, len(audio) // 10) # a gap should never exceed 10 percent of the audio
offset = random.randint(0, max(0, len(audio) - size - 1))
audio[offset:offset + size] = 0
sample.audio = audio
class Volume(Augmentation):
"""See "Volume augmentation" in TRAINING.rst"""
def __init__(self, p=1.0, dbfs=3.0103):
super(Volume, self).__init__(p)
self.target_dbfs = float_range(dbfs)
def apply(self, sample, clock):
sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP)
target_dbfs = pick_value_from_range(self.target_dbfs, clock=clock)
sample.audio = normalize_audio(sample.audio, dbfs=target_dbfs)
def parse_augmentation(augmentation_spec):
"""
Parses an augmentation specification.
Parameters
----------
augmentation_spec : str
Augmentation specification like "reverb[delay=20.0,decay=-20]".
Returns
-------
Instance of an augmentation class from util.signal_augmentations.*.
"""
match = SPEC_PARSER.match(augmentation_spec)
if not match:
raise ValueError('Augmentation specification has wrong format')
cls_name = match.group('cls')
cls_name = cls_name[0].upper() + cls_name[1:]
augmentation_cls = globals()[cls_name] if cls_name in globals() else None
if not issubclass(augmentation_cls, Augmentation) or augmentation_cls == Augmentation:
raise ValueError('Unknown augmentation: {}'.format(cls_name))
parameters = match.group('params')
parameters = [] if parameters is None else parameters.split(',')
args = []
kwargs = {}
for parameter in parameters:
pair = tuple(list(map(str.strip, (parameter.split('=')))))
if len(pair) == 1:
args.append(pair)
elif len(pair) == 2:
kwargs[pair[0]] = pair[1]
else:
raise ValueError('Unable to parse augmentation value assignment')
return augmentation_cls(*args, **kwargs)

Просмотреть файл

@ -1,127 +0,0 @@
import tensorflow as tf
import tensorflow.compat.v1 as tfv1
from .sparse_image_warp import sparse_image_warp
def augment_freq_time_mask(spectrogram,
frequency_masking_para=30,
time_masking_para=10,
frequency_mask_num=3,
time_mask_num=3):
time_max = tf.shape(spectrogram)[1]
freq_max = tf.shape(spectrogram)[2]
# Frequency masking
for _ in range(frequency_mask_num):
f = tf.random.uniform(shape=(), minval=0, maxval=frequency_masking_para, dtype=tf.dtypes.int32)
f0 = tf.random.uniform(shape=(), minval=0, maxval=freq_max - f, dtype=tf.dtypes.int32)
value_ones_freq_prev = tf.ones(shape=[1, time_max, f0])
value_zeros_freq = tf.zeros(shape=[1, time_max, f])
value_ones_freq_next = tf.ones(shape=[1, time_max, freq_max-(f0+f)])
freq_mask = tf.concat([value_ones_freq_prev, value_zeros_freq, value_ones_freq_next], axis=2)
# mel_spectrogram[:, f0:f0 + f, :] = 0 #can't assign to tensor
# mel_spectrogram[:, f0:f0 + f, :] = value_zeros_freq #can't assign to tensor
spectrogram = spectrogram*freq_mask
# Time masking
for _ in range(time_mask_num):
t = tf.random.uniform(shape=(), minval=0, maxval=time_masking_para, dtype=tf.dtypes.int32)
t0 = tf.random.uniform(shape=(), minval=0, maxval=time_max - t, dtype=tf.dtypes.int32)
value_zeros_time_prev = tf.ones(shape=[1, t0, freq_max])
value_zeros_time = tf.zeros(shape=[1, t, freq_max])
value_zeros_time_next = tf.ones(shape=[1, time_max-(t0+t), freq_max])
time_mask = tf.concat([value_zeros_time_prev, value_zeros_time, value_zeros_time_next], axis=1)
# mel_spectrogram[:, :, t0:t0 + t] = 0 #can't assign to tensor
# mel_spectrogram[:, :, t0:t0 + t] = value_zeros_time #can't assign to tensor
spectrogram = spectrogram*time_mask
return spectrogram
def augment_pitch_and_tempo(spectrogram,
max_tempo=1.2,
max_pitch=1.1,
min_pitch=0.95):
original_shape = tf.shape(spectrogram)
choosen_pitch = tf.random.uniform(shape=(), minval=min_pitch, maxval=max_pitch)
choosen_tempo = tf.random.uniform(shape=(), minval=1, maxval=max_tempo)
new_freq_size = tf.cast(tf.cast(original_shape[2], tf.float32)*choosen_pitch, tf.int32)
new_time_size = tf.cast(tf.cast(original_shape[1], tf.float32)/(choosen_tempo), tf.int32)
spectrogram_aug = tf.image.resize_bilinear(tf.expand_dims(spectrogram, -1), [new_time_size, new_freq_size])
spectrogram_aug = tf.image.crop_to_bounding_box(spectrogram_aug, offset_height=0, offset_width=0, target_height=tf.shape(spectrogram_aug)[1], target_width=tf.minimum(original_shape[2], new_freq_size))
spectrogram_aug = tf.cond(choosen_pitch < 1,
lambda: tf.image.pad_to_bounding_box(spectrogram_aug, offset_height=0, offset_width=0,
target_height=tf.shape(spectrogram_aug)[1], target_width=original_shape[2]),
lambda: spectrogram_aug)
return spectrogram_aug[:, :, :, 0]
def augment_speed_up(spectrogram,
speed_std=0.1):
original_shape = tf.shape(spectrogram)
choosen_speed = tf.math.abs(tf.random.normal(shape=(), stddev=speed_std)) # abs makes sure the augmention will only speed up
choosen_speed = 1 + choosen_speed
new_freq_size = tf.cast(tf.cast(original_shape[2], tf.float32), tf.int32)
new_time_size = tf.cast(tf.cast(original_shape[1], tf.float32)/(choosen_speed), tf.int32)
spectrogram_aug = tf.image.resize_bilinear(tf.expand_dims(spectrogram, -1), [new_time_size, new_freq_size])
return spectrogram_aug[:, :, :, 0]
def augment_dropout(spectrogram,
keep_prob=0.95):
return tf.nn.dropout(spectrogram, rate=1-keep_prob)
def augment_sparse_warp(spectrogram, time_warping_para=20, interpolation_order=2, regularization_weight=0.0, num_boundary_points=1, num_control_points=1):
"""Reference: https://arxiv.org/pdf/1904.08779.pdf
Args:
spectrogram: `[batch, time, frequency]` float `Tensor`
time_warping_para: 'W' parameter in paper
interpolation_order: used to put into `sparse_image_warp`
regularization_weight: used to put into `sparse_image_warp`
num_boundary_points: used to put into `sparse_image_warp`,
default=1 means boundary points on 4 corners of the image
num_control_points: number of control points
Returns:
warped_spectrogram: `[batch, time, frequency]` float `Tensor` with same
type as input image.
"""
# reshape to fit `sparse_image_warp`'s input shape
# (1, time steps, freq, 1), batch_size must be 1
spectrogram = tf.expand_dims(spectrogram, -1)
original_shape = tf.shape(spectrogram)
tau, freq_size = original_shape[1], original_shape[2]
# to protect short audio
time_warping_para = tf.math.minimum(
time_warping_para, tf.math.subtract(tf.math.floordiv(tau, 2), 1))
# don't choose boundary frequency
choosen_freqs = tf.random.shuffle(
tf.add(tf.range(freq_size - 3), 1))[0: num_control_points]
source_max = tau - time_warping_para
source_min = tf.math.minimum(source_max - num_control_points, time_warping_para)
choosen_times = tf.random.shuffle(tf.range(source_min, limit=source_max))[0: num_control_points]
dest_time_widths = tfv1.random_uniform([num_control_points], tf.negative(time_warping_para), time_warping_para, tf.int32)
sources = []
dests = []
for i in range(num_control_points):
# generate source points `t` of time axis between (W, tau-W)
rand_source_time = choosen_times[i]
rand_dest_time = rand_source_time + dest_time_widths[i]
choosen_freq = choosen_freqs[i]
sources.append([rand_source_time, choosen_freq])
dests.append([rand_dest_time, choosen_freq])
source_control_point_locations = tf.cast([sources], tf.float32)
dest_control_point_locations = tf.cast([dests], tf.float32)
warped_spectrogram, _ = sparse_image_warp(spectrogram,
source_control_point_locations=source_control_point_locations,
dest_control_point_locations=dest_control_point_locations,
interpolation_order=interpolation_order,
regularization_weight=regularization_weight,
num_boundary_points=num_boundary_points)
return tf.reshape(warped_spectrogram, shape=(1, -1, freq_size))