зеркало из https://github.com/mozilla/DeepSpeech.git
Refactoring of TF based augmentations
This commit is contained in:
Родитель
bfaa68945a
Коммит
d94db7ca43
16
bin/play.py
16
bin/play.py
|
@ -10,7 +10,8 @@ import random
|
|||
import argparse
|
||||
|
||||
from deepspeech_training.util.audio import LOADABLE_AUDIO_EXTENSIONS, AUDIO_TYPE_PCM, AUDIO_TYPE_WAV
|
||||
from deepspeech_training.util.sample_collections import SampleList, LabeledSample, samples_from_source, augment_samples
|
||||
from deepspeech_training.util.sample_collections import SampleList, LabeledSample, samples_from_source
|
||||
from deepspeech_training.util.augmentations import parse_augmentations, apply_sample_augmentations, SampleAugmentation
|
||||
|
||||
|
||||
def get_samples_in_play_order():
|
||||
|
@ -38,12 +39,15 @@ def get_samples_in_play_order():
|
|||
|
||||
|
||||
def play_collection():
|
||||
augmentations = parse_augmentations(CLI_ARGS.augment)
|
||||
if any(map(lambda a: not isinstance(a, SampleAugmentation), augmentations)):
|
||||
print("Warning: Some of the augmentations cannot be simulated by this command.")
|
||||
samples = get_samples_in_play_order()
|
||||
samples = augment_samples(samples,
|
||||
audio_type=AUDIO_TYPE_PCM,
|
||||
augmentation_specs=CLI_ARGS.augment,
|
||||
process_ahead=0,
|
||||
fixed_clock=CLI_ARGS.clock)
|
||||
samples = apply_sample_augmentations(samples,
|
||||
audio_type=AUDIO_TYPE_PCM,
|
||||
augmentations=augmentations,
|
||||
process_ahead=0,
|
||||
clock=CLI_ARGS.clock)
|
||||
for sample in samples:
|
||||
if not CLI_ARGS.quiet:
|
||||
print('Sample "{}"'.format(sample.sample_id), file=sys.stderr)
|
||||
|
|
|
@ -0,0 +1,28 @@
|
|||
#!/bin/sh
|
||||
|
||||
set -xe
|
||||
|
||||
ldc93s1_dir="./data/smoke_test"
|
||||
ldc93s1_csv="${ldc93s1_dir}/ldc93s1.csv"
|
||||
|
||||
if [ ! -f "${ldc93s1_dir}/ldc93s1.csv" ]; then
|
||||
echo "Downloading and preprocessing LDC93S1 example data, saving in ${ldc93s1_dir}."
|
||||
python -u bin/import_ldc93s1.py ${ldc93s1_dir}
|
||||
fi;
|
||||
|
||||
# Force only one visible device because we have a single-sample dataset
|
||||
# and when trying to run on multiple devices (like GPUs), this will break
|
||||
export CUDA_VISIBLE_DEVICES=0
|
||||
|
||||
python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
|
||||
--train_files ${ldc93s1_csv} --train_batch_size 1 \
|
||||
--augment speed \
|
||||
--augment dropout \
|
||||
--augment pitch_and_tempo \
|
||||
--augment time_mask \
|
||||
--augment frequency_mask \
|
||||
--augment add \
|
||||
--augment multiply \
|
||||
--augment warp \
|
||||
--n_hidden 100 \
|
||||
--epochs 1
|
|
@ -41,12 +41,6 @@ if ! $compare --if-differ "${ldc93s1_wav}" /tmp/reverb-test.wav; then
|
|||
exit 1
|
||||
fi
|
||||
|
||||
$play ${ldc93s1_wav} --augment gaps[n=10,size=100.0] --pipe >/tmp/gaps-test.wav
|
||||
if ! $compare --if-differ "${ldc93s1_wav}" /tmp/gaps-test.wav; then
|
||||
echo "Gaps augmentation had no effect or changed basic sample properties"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
$play ${ldc93s1_wav} --augment resample[rate=4000] --pipe >/tmp/resample-test.wav
|
||||
if ! $compare --if-differ "${ldc93s1_wav}" /tmp/resample-test.wav; then
|
||||
echo "Resample augmentation had no effect or changed basic sample properties"
|
184
doc/TRAINING.rst
184
doc/TRAINING.rst
|
@ -270,12 +270,6 @@ Augmentation
|
|||
|
||||
Augmentation is a useful technique for better generalization of machine learning models. Thus, a pre-processing pipeline with various augmentation techniques on raw pcm and spectrogram has been implemented and can be used while training the model. Following are the available augmentation techniques that can be enabled at training time by using the corresponding flags in the command line.
|
||||
|
||||
|
||||
Audio Augmentation
|
||||
------------------
|
||||
|
||||
Augmentations that are applied before potential feature caching can be specified through the ``--augment`` flag. Being a multi-flag, it can be specified multiple times (see below for an example).
|
||||
|
||||
Each sample of the training data will get treated by every specified augmentation in their given order. However: whether an augmentation will actually get applied to a sample is decided by chance on base of the augmentation's probability value. For example a value of ``p=0.1`` would apply the according augmentation to just 10% of all samples. This also means that augmentations are not mutually exclusive on a per-sample basis.
|
||||
|
||||
The ``--augment`` flag uses a common syntax for all augmentation types:
|
||||
|
@ -297,14 +291,31 @@ In the documentation below, whenever a value is specified as ``<float-range>`` o
|
|||
|
||||
* ``<value>~<r>``: A center value with a randomization radius around it. E.g. ``1.2~0.4`` will result in picking of a uniformly random value between 0.8 and 1.6 on each sample augmentation.
|
||||
|
||||
* ``<start>:<end>``: The value will range from `<start>` at the beginning of an epoch to `<end>` at the end of an epoch. E.g. ``-0.2:1.2`` (float) or ``2000:4000`` (int)
|
||||
* ``<start>:<end>``: The value will range from `<start>` at the beginning of the training to `<end>` at the end of the training. E.g. ``-0.2:1.2`` (float) or ``2000:4000`` (int)
|
||||
|
||||
* ``<start>:<end>~<r>``: Combination of the two previous cases with a ranging center value. E.g. ``4-6~2`` would at the beginning of an epoch pick values between 2 and 6 and at the end of an epoch between 4 and 8.
|
||||
* ``<start>:<end>~<r>``: Combination of the two previous cases with a ranging center value. E.g. ``4-6~2`` would at the beginning of the training pick values between 2 and 6 and at the end of the training between 4 and 8.
|
||||
|
||||
Ranges specified with integer limits will only assume integer (rounded) values.
|
||||
|
||||
If feature caching is enabled, these augmentations will only be performed on the first epoch and the result will be reused for subsequent epochs. The flag ``--augmentations_per_epoch N`` (by default `N` is 1) could be used to get more than one epoch worth of augmentations into the cache. During training, each epoch will do ``N`` passes over the training set, each time performing augmentation independently of previous passes. Be aware: this will also multiply the required size of the feature cache if it's enabled.
|
||||
.. warning::
|
||||
If feature caching is enabled and infinite (default), these augmentations will only be performed on first epoch and the result will be reused for subsequent epochs. This would not only hinder value ranges from reaching their intended final values, but could also lead to unintended over-fitting. In this case flag ``--cache_for_epochs N`` (with N > 1) should be used to periodically invalidate the cache and thus allow samples to be re-augmented in new ways and with current range-values.
|
||||
|
||||
Every augmentation is targeting a certain data representation of the sample - further on called *domain*.
|
||||
Augmentations are applied domain-wise in the following order:
|
||||
|
||||
1. **sample** domain: The sample just got loaded and its waveform is represented as a NumPy array. For implementation reasons these augmentations are the only ones that can be "simulated" through ``bin/play.py``.
|
||||
|
||||
2. **signal** domain: The sample waveform is represented as a tensor.
|
||||
|
||||
3. **spectrogram** domain: The sample spectrogram is represented as a tensor.
|
||||
|
||||
4. **features** domain: The sample's MEL spectrogram features are represented as a tensor.
|
||||
|
||||
During each phase augmentations are applied in command-line order (the **warp** augmentation being the only exception).
|
||||
|
||||
|
||||
Sample domain augmentations
|
||||
---------------------------
|
||||
|
||||
**Overlay augmentation** ``--augment overlay[p=<float>,source=<str>,snr=<float-range>,layers=<int-range>]``
|
||||
Layers another audio source (multiple times) onto augmented samples.
|
||||
|
@ -328,16 +339,6 @@ If feature caching is enabled, these augmentations will only be performed on the
|
|||
* **decay**: sound decay in dB per reflection - higher values will result in a less reflective perceived "room"
|
||||
|
||||
|
||||
**Gaps augmentation** ``--augment gaps[p=<float>,n=<int-range>,size=<float-range>]``
|
||||
Sets time-intervals within the augmented samples to zero (silence) at random positions.
|
||||
|
||||
* **p**: probability value between 0.0 (never) and 1.0 (always) if a given sample gets augmented by this method
|
||||
|
||||
* **n**: number of intervals to set to zero
|
||||
|
||||
* **size**: duration of intervals in ms
|
||||
|
||||
|
||||
**Resample augmentation** ``--augment resample[p=<float>,rate=<int-range>]``
|
||||
Resamples augmented samples to another sample rate and then resamples back to the original sample rate.
|
||||
|
||||
|
@ -361,6 +362,96 @@ If feature caching is enabled, these augmentations will only be performed on the
|
|||
|
||||
* **dbfs** : target volume in dBFS (default value of 3.0103 will normalize min and max amplitudes to -1.0/1.0)
|
||||
|
||||
Spectrogram domain augmentations
|
||||
--------------------------------
|
||||
|
||||
**Pitch and tempo augmentation** ``--augment pitch_and_tempo[p=<float>,pitch=<float-range>,tempo=<float-range>]``
|
||||
Scales spectrogram on time and frequency axis and thus changes pitch and playback tempo.
|
||||
|
||||
* **p**: probability value between 0.0 (never) and 1.0 (always) if a given sample gets augmented by this method
|
||||
|
||||
* **pitch**: pitch factor by with the frequency axis is scaled (e.g. a value of 2.0 will raise audio frequency by one octave)
|
||||
|
||||
* **tempo**: tempo factor by which the time axis is stretched or shrunken (e.g. a value of 2.0 will double playback tempo)
|
||||
|
||||
|
||||
**Speed augmentation** ``--augment speed[p=<float>,factor=<float-range>]``
|
||||
Scales spectrogram on time axis and thus changes playback tempo.
|
||||
|
||||
* **p**: probability value between 0.0 (never) and 1.0 (always) if a given sample gets augmented by this method
|
||||
|
||||
* **factor**: speed factor by which the time axis is stretched or shrunken (e.g. a value of 2.0 will double playback tempo)
|
||||
|
||||
|
||||
**Warp augmentation** ``--augment warp[p=<float>,shift=<float-range>,order=<int-range>,nbp=<int-range>,ncp=<int-range>,regularization_weight=<float>]``
|
||||
Applies a non-linear image warp to the spectrogram, where the warp is specified by the source and destination locations of a (potentially small) number of control points. Of all specified spectrogram augmentations this one will always be applied first.
|
||||
|
||||
* **p**: probability value between 0.0 (never) and 1.0 (always) if a given sample gets augmented by this method
|
||||
|
||||
* **shift**: maximum shift distance of control points on time axis in ms
|
||||
|
||||
* **order**: polynomial order used by the spline interpolation
|
||||
|
||||
* **nbp**: how many zero-flow boundary points to include at each spectrogram edge
|
||||
|
||||
* **ncp**: how many control points to warp inside the spectrogram
|
||||
|
||||
* **regularization_weight**: weight on smoothness regularizer in interpolation
|
||||
|
||||
|
||||
**Frequency mask augmentation** ``--augment frequency_mask[p=<float>,n=<int-range>,size=<int-range>]``
|
||||
Sets frequency-intervals within the augmented samples to zero (silence) at random frequencies.
|
||||
|
||||
* **p**: probability value between 0.0 (never) and 1.0 (always) if a given sample gets augmented by this method
|
||||
|
||||
* **n**: number of intervals to mask
|
||||
|
||||
* **size**: number of frequency bands to mask per interval
|
||||
|
||||
Multi domain augmentations
|
||||
--------------------------
|
||||
|
||||
**Time mask augmentation** ``--augment time_mask[p=<float>,n=<int-range>,size=<float-range>,domain=<domain>]``
|
||||
Sets time-intervals within the augmented samples to zero (silence) at random positions.
|
||||
|
||||
* **p**: probability value between 0.0 (never) and 1.0 (always) if a given sample gets augmented by this method
|
||||
|
||||
* **n**: number of intervals to set to zero
|
||||
|
||||
* **size**: duration of intervals in ms
|
||||
|
||||
* **domain**: data representation to apply augmentation to - "signal", "features" or "spectrogram" (default)
|
||||
|
||||
|
||||
**Dropout augmentation** ``--augment dropout[p=<float>,rate=<float-range>,domain=<domain>]``
|
||||
Zeros random data points of the targeted data representation.
|
||||
|
||||
* **p**: probability value between 0.0 (never) and 1.0 (always) if a given sample gets augmented by this method
|
||||
|
||||
* **rate**: dropout rate ranging from 0.0 for no dropout to 1.0 for 100% dropout
|
||||
|
||||
* **domain**: data representation to apply augmentation to - "signal", "features" or "spectrogram" (default)
|
||||
|
||||
|
||||
**Add augmentation** ``--augment add[p=<float>,stddev=<float-range>,domain=<domain>]``
|
||||
Adds random values picked from a normal distribution (with a mean of 0.0) to all data points of the targeted data representation.
|
||||
|
||||
* **p**: probability value between 0.0 (never) and 1.0 (always) if a given sample gets augmented by this method
|
||||
|
||||
* **stddev**: standard deviation of the normal distribution to pick values from
|
||||
|
||||
* **domain**: data representation to apply augmentation to - "signal", "features" (default) or "spectrogram"
|
||||
|
||||
|
||||
**Multiply augmentation** ``--augment multiply[p=<float>,stddev=<float-range>,domain=<domain>]``
|
||||
Multiplies all data points of the targeted data representation with random values picked from a normal distribution (with a mean of 1.0).
|
||||
|
||||
* **p**: probability value between 0.0 (never) and 1.0 (always) if a given sample gets augmented by this method
|
||||
|
||||
* **stddev**: standard deviation of the normal distribution to pick values from
|
||||
|
||||
* **domain**: data representation to apply augmentation to - "signal", "features" (default) or "spectrogram"
|
||||
|
||||
|
||||
Example training with all augmentations:
|
||||
|
||||
|
@ -368,18 +459,26 @@ Example training with all augmentations:
|
|||
|
||||
python -u DeepSpeech.py \
|
||||
--train_files "train.sdb" \
|
||||
--augmentations_per_epoch 10 \
|
||||
--feature_cache ./feature.cache \
|
||||
--cache_for_epochs 10 \
|
||||
--epochs 100 \
|
||||
--augment overlay[p=0.5,source=noise.sdb,layers=1,snr=50:20~10] \
|
||||
--augment overlay[p=0.2,source=voices.sdb,layers=10:6,snr=50:20~10] \
|
||||
--augment reverb[p=0.1,delay=50.0~30.0,decay=10.0:2.0~1.0] \
|
||||
--augment gaps[p=0.05,n=1:3~2,size=10:100] \
|
||||
--augment resample[p=0.1,rate=12000:8000~4000] \
|
||||
--augment codec[p=0.1,bitrate=48000:16000] \
|
||||
--augment volume[p=0.1,dbfs=-10:-40] \
|
||||
--augment pitch_and_tempo[p=0.1,pitch=1~0.2,tempo=1~0.2] \
|
||||
--augment speed[p=0.1,factor=1~0.5] \
|
||||
--augment warp[p=0.1,shift=30:60~20,ncp=4~3] \
|
||||
--augment frequency_mask[p=0.1,n=1:3,size=1:5] \
|
||||
--augment time_mask[p=0.1,domain=signal,n=3:10~2,size=50:100~40] \
|
||||
--augment dropout[p=0.1,rate=0.05] \
|
||||
--augment add[p=0.1,domain=signal,stddev=0~0.5] \
|
||||
--augment multiply[p=0.1,domain=features,stddev=0~0.5] \
|
||||
[...]
|
||||
|
||||
|
||||
The ``bin/play.py`` tool also supports ``--augment`` parameters and can be used for experimenting with different configurations.
|
||||
The ``bin/play.py`` tool also supports ``--augment`` parameters (for sample domain augmentations) and can be used for experimenting with different configurations.
|
||||
|
||||
Example of playing all samples with reverberation and maximized volume:
|
||||
|
||||
|
@ -393,42 +492,3 @@ Example simulation of the codec augmentation of a wav-file first at the beginnin
|
|||
|
||||
bin/play.py --augment codec[p=0.1,bitrate=48000:16000] --clock 0.0 test.wav
|
||||
bin/play.py --augment codec[p=0.1,bitrate=48000:16000] --clock 1.0 test.wav
|
||||
|
||||
|
||||
The following augmentations are applied after feature caching, hence the way they are applied will not repeat epoch-wise.
|
||||
Working on spectrogram and feature level, `bin/play.py` offers no ability to simulate them.
|
||||
|
||||
#. **Standard deviation for Gaussian additive noise:** ``--data_aug_features_additive``
|
||||
#. **Standard deviation for Normal distribution around 1 for multiplicative noise:** ``--data_aug_features_multiplicative``
|
||||
#. **Standard deviation for speeding-up tempo. If Standard deviation is 0, this augmentation is not performed:** ``--augmentation_speed_up_std``
|
||||
|
||||
Spectrogram Augmentation
|
||||
------------------------
|
||||
|
||||
Inspired by Google Paper on `SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition <https://arxiv.org/abs/1904.08779>`_
|
||||
|
||||
|
||||
#.
|
||||
**Keep rate of dropout augmentation on a spectrogram (if 1, no dropout will be performed on the spectrogram)**\ :
|
||||
|
||||
|
||||
* Keep Rate : ``--augmentation_spec_dropout_keeprate value between range [0 - 1]``
|
||||
|
||||
#.
|
||||
**Whether to use frequency and time masking augmentation:**
|
||||
|
||||
|
||||
* Enable / Disable : ``--augmentation_freq_and_time_masking / --noaugmentation_freq_and_time_masking``
|
||||
* Max range of masks in the frequency domain when performing freqtime-mask augmentation: ``--augmentation_freq_and_time_masking_freq_mask_range eg: 5``
|
||||
* Number of masks in the frequency domain when performing freqtime-mask augmentation: ``--augmentation_freq_and_time_masking_number_freq_masks eg: 3``
|
||||
* Max range of masks in the time domain when performing freqtime-mask augmentation: ``--augmentation_freq_and_time_masking_time_mask_range eg: 2``
|
||||
* Number of masks in the time domain when performing freqtime-mask augmentation: ``--augmentation_freq_and_time_masking_number_time_masks eg: 3``
|
||||
|
||||
#.
|
||||
**Whether to use spectrogram speed and tempo scaling:**
|
||||
|
||||
|
||||
* Enable / Disable : ``--augmentation_pitch_and_tempo_scaling / --noaugmentation_pitch_and_tempo_scaling``
|
||||
* Min value of pitch scaling: ``--augmentation_pitch_and_tempo_scaling_min_pitch eg:0.95``
|
||||
* Max value of pitch scaling: ``--augmentation_pitch_and_tempo_scaling_max_pitch eg:1.2``
|
||||
* Max value of tempo scaling: ``--augmentation_pitch_and_tempo_scaling_max_tempo eg:1.2``
|
||||
|
|
|
@ -22,7 +22,8 @@ popd
|
|||
set +o pipefail
|
||||
|
||||
pushd ${HOME}/DeepSpeech/ds/
|
||||
time ./bin/run-tc-signal_augmentations.sh
|
||||
time ./bin/run-tc-sample_augmentations.sh
|
||||
time ./bin/run-tc-graph_augmentations.sh
|
||||
popd
|
||||
|
||||
virtualenv_deactivate "${pyalias}" "deepspeech"
|
|
@ -6,7 +6,7 @@ build:
|
|||
>
|
||||
apt-get -qq update && apt-get -qq -y install ${training.packages_trusty.apt}
|
||||
args:
|
||||
tests_cmdline: "${system.homedir.linux}/DeepSpeech/ds/taskcluster/tc-signal_augmentation-tests.sh 3.6.10:m"
|
||||
tests_cmdline: "${system.homedir.linux}/DeepSpeech/ds/taskcluster/tc-augmentation-tests.sh 3.6.10:m"
|
||||
metadata:
|
||||
name: "DeepSpeech Linux AMD64 CPU signal augmentations Py3.6"
|
||||
description: "Augmenting LDC93S1 sample in different ways for Linux/AMD64 16kHz Python 3.6, CPU only, optimized version"
|
|
@ -10,7 +10,6 @@ DESIRED_LOG_LEVEL = sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.a
|
|||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = DESIRED_LOG_LEVEL
|
||||
|
||||
import absl.app
|
||||
import json
|
||||
import numpy as np
|
||||
import progressbar
|
||||
import shutil
|
||||
|
@ -32,7 +31,7 @@ from six.moves import zip, range
|
|||
from .util.config import Config, initialize_globals
|
||||
from .util.checkpoints import load_or_init_graph_for_training, load_graph_for_evaluation
|
||||
from .util.evaluate_tools import save_samples_json
|
||||
from .util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features
|
||||
from .util.feeding import create_dataset, audio_to_features, audiofile_to_features
|
||||
from .util.flags import create_flags, FLAGS
|
||||
from .util.helpers import check_ctcdecoder_version, ExceptionBox
|
||||
from .util.logging import create_progressbar, log_debug, log_error, log_info, log_progress, log_warn
|
||||
|
@ -407,26 +406,13 @@ def log_grads_and_vars(grads_and_vars):
|
|||
|
||||
|
||||
def train():
|
||||
do_cache_dataset = True
|
||||
|
||||
# pylint: disable=too-many-boolean-expressions
|
||||
if (FLAGS.data_aug_features_multiplicative > 0 or
|
||||
FLAGS.data_aug_features_additive > 0 or
|
||||
FLAGS.augmentation_spec_dropout_keeprate < 1 or
|
||||
FLAGS.augmentation_freq_and_time_masking or
|
||||
FLAGS.augmentation_pitch_and_tempo_scaling or
|
||||
FLAGS.augmentation_speed_up_std > 0 or
|
||||
FLAGS.augmentation_sparse_warp):
|
||||
do_cache_dataset = False
|
||||
|
||||
exception_box = ExceptionBox()
|
||||
|
||||
# Create training and validation datasets
|
||||
train_set = create_dataset(FLAGS.train_files.split(','),
|
||||
batch_size=FLAGS.train_batch_size,
|
||||
repetitions=FLAGS.augmentations_per_epoch,
|
||||
augmentation_specs=FLAGS.augment,
|
||||
enable_cache=FLAGS.feature_cache and do_cache_dataset,
|
||||
epochs=FLAGS.epochs,
|
||||
augmentations=Config.augmentations,
|
||||
cache_path=FLAGS.feature_cache,
|
||||
train_phase=True,
|
||||
exception_box=exception_box,
|
||||
|
@ -541,6 +527,12 @@ def train():
|
|||
step_summary_writer = step_summary_writers.get(set_name)
|
||||
checkpoint_time = time.time()
|
||||
|
||||
if is_train and FLAGS.cache_for_epochs > 0 and FLAGS.feature_cache is not None:
|
||||
feature_cache_index = FLAGS.feature_cache + '.index'
|
||||
if epoch % FLAGS.cache_for_epochs == 0 and os.path.isfile(feature_cache_index):
|
||||
log_info('Invalidating feature cache')
|
||||
os.remove(feature_cache_index) # this will let TF also overwrite the related cache data files
|
||||
|
||||
# Setup progress bar
|
||||
class LossWidget(progressbar.widgets.FormatLabel):
|
||||
def __init__(self):
|
||||
|
@ -567,11 +559,6 @@ def train():
|
|||
session.run([train_op, global_step, loss, non_finite_files, step_summaries_op],
|
||||
feed_dict=feed_dict)
|
||||
exception_box.raise_if_set()
|
||||
except tf.errors.InvalidArgumentError as err:
|
||||
if FLAGS.augmentation_sparse_warp:
|
||||
log_info("Ignoring sparse warp error: {}".format(err))
|
||||
continue
|
||||
raise
|
||||
except tf.errors.OutOfRangeError:
|
||||
exception_box.raise_if_set()
|
||||
break
|
||||
|
@ -680,7 +667,7 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
|
|||
# Create feature computation graph
|
||||
input_samples = tfv1.placeholder(tf.float32, [Config.audio_window_samples], 'input_samples')
|
||||
samples = tf.expand_dims(input_samples, -1)
|
||||
mfccs, _ = samples_to_mfccs(samples, FLAGS.audio_sample_rate)
|
||||
mfccs, _ = audio_to_features(samples, FLAGS.audio_sample_rate)
|
||||
mfccs = tf.identity(mfccs, name='mfccs')
|
||||
|
||||
# Input tensor will be of shape [batch_size, n_steps, 2*n_context+1, n_input]
|
||||
|
|
|
@ -0,0 +1,564 @@
|
|||
|
||||
import os
|
||||
import re
|
||||
import math
|
||||
import random
|
||||
import numpy as np
|
||||
|
||||
from multiprocessing import Queue, Process
|
||||
from .audio import gain_db_to_ratio, max_dbfs, normalize_audio, AUDIO_TYPE_NP, AUDIO_TYPE_PCM, AUDIO_TYPE_OPUS
|
||||
from .helpers import LimitingPool, int_range, float_range, pick_value_from_range, tf_pick_value_from_range, MEGABYTE
|
||||
|
||||
BUFFER_SIZE = 1 * MEGABYTE
|
||||
SPEC_PARSER = re.compile(r'^(?P<cls>[a-z_]+)(\[(?P<params>.*)\])?$')
|
||||
|
||||
|
||||
class Augmentation:
|
||||
def __init__(self, p=1.0):
|
||||
self.probability = float(p)
|
||||
|
||||
|
||||
class SampleAugmentation(Augmentation):
|
||||
def start(self, buffering=BUFFER_SIZE):
|
||||
pass
|
||||
|
||||
def apply(self, sample, clock=0.0):
|
||||
raise NotImplementedError
|
||||
|
||||
def stop(self):
|
||||
pass
|
||||
|
||||
|
||||
class GraphAugmentation(Augmentation):
|
||||
def __init__(self, p=1.0, domain='spectrogram'):
|
||||
super(GraphAugmentation, self).__init__(p)
|
||||
if domain not in ['signal', 'spectrogram', 'features']:
|
||||
raise ValueError('Unsupported augmentation domain: {}'.format(domain))
|
||||
self.domain = domain
|
||||
|
||||
def apply(self, tensor, clock=0.0):
|
||||
raise NotImplementedError
|
||||
|
||||
def apply_with_probability(self, tensor, clock=0.0):
|
||||
import tensorflow as tf # pylint: disable=import-outside-toplevel
|
||||
rv = tf.random.stateless_uniform([], seed=(clock * tf.int32.min, clock * tf.int32.max))
|
||||
return tf.cond(tf.less(rv, self.probability),
|
||||
lambda: self.apply(tensor, clock=clock),
|
||||
lambda: tensor)
|
||||
|
||||
def maybe_apply(self, domain, tensor, clock=0.0):
|
||||
if domain == self.domain:
|
||||
return self.apply_with_probability(tensor, clock=clock)
|
||||
return tensor
|
||||
|
||||
|
||||
def parse_augmentation(augmentation_spec):
|
||||
"""
|
||||
Parses an augmentation specification.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
augmentation_spec : str
|
||||
Augmentation specification like "reverb[delay=20.0,decay=1.0]".
|
||||
|
||||
Returns
|
||||
-------
|
||||
Instance of an augmentation class from util.augmentations.*.
|
||||
"""
|
||||
match = SPEC_PARSER.match(augmentation_spec)
|
||||
if not match:
|
||||
raise ValueError('Augmentation specification has wrong format')
|
||||
cls_name = ''.join(map(lambda p: p[0].upper() + p[1:], match.group('cls').split('_')))
|
||||
augmentation_cls = globals()[cls_name] if cls_name in globals() else None
|
||||
if augmentation_cls is None or not issubclass(augmentation_cls, Augmentation) or augmentation_cls == Augmentation:
|
||||
raise ValueError('Unknown augmentation: {}'.format(cls_name))
|
||||
parameters = match.group('params')
|
||||
parameters = [] if parameters is None else parameters.split(',')
|
||||
args = []
|
||||
kwargs = {}
|
||||
for parameter in parameters:
|
||||
pair = tuple(list(map(str.strip, (parameter.split('=')))))
|
||||
if len(pair) == 1:
|
||||
args.append(pair)
|
||||
elif len(pair) == 2:
|
||||
kwargs[pair[0]] = pair[1]
|
||||
else:
|
||||
raise ValueError('Unable to parse augmentation value assignment')
|
||||
return augmentation_cls(*args, **kwargs)
|
||||
|
||||
|
||||
def parse_augmentations(augmentation_specs):
|
||||
"""
|
||||
Parses an augmentation specification.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
augmentation_specs : list of str
|
||||
List of augmentation specifications like ["reverb[delay=20.0,decay=1.0]", "volume"].
|
||||
|
||||
Returns
|
||||
-------
|
||||
List of augmentation class instances from util.augmentations.*.
|
||||
"""
|
||||
return [] if augmentation_specs is None else list(map(parse_augmentation, augmentation_specs))
|
||||
|
||||
|
||||
def apply_graph_augmentations(domain, tensor, augmentations, clock=0.0):
|
||||
"""
|
||||
Augments training sample tensor of a certain domain with matching augmentations of passed list.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
domain : str
|
||||
Domain of the tensor to apply augmentations to. One of "signal", "spectrogram" or "features"
|
||||
tensor : Tensor of type float32
|
||||
Tensor to apply augmentations to.
|
||||
augmentations : list of augmentation class instances from util.augmentations.*.
|
||||
List of augmentations of which only the spectrogram ones will get applied to the samples.
|
||||
clock : Tensor of type float32
|
||||
Time indicator for augmentation value-ranges. Running from 0.0 (start of training) to 1.0 (end of training).
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tensor of type float32
|
||||
The augmented spectrogram
|
||||
"""
|
||||
if augmentations is not None:
|
||||
# Warp has to come before any spectrogram masking
|
||||
for augmentation in augmentations:
|
||||
if isinstance(augmentation, Warp):
|
||||
tensor = augmentation.maybe_apply(domain, tensor, clock=clock)
|
||||
for augmentation in augmentations:
|
||||
if isinstance(augmentation, GraphAugmentation) and not isinstance(augmentation, Warp):
|
||||
tensor = augmentation.maybe_apply(domain, tensor, clock=clock)
|
||||
return tensor
|
||||
|
||||
|
||||
class AugmentationContext:
|
||||
def __init__(self, target_audio_type, augmentations):
|
||||
self.target_audio_type = target_audio_type
|
||||
self.augmentations = augmentations
|
||||
|
||||
|
||||
AUGMENTATION_CONTEXT = None
|
||||
|
||||
|
||||
def _init_augmentation_worker(preparation_context):
|
||||
global AUGMENTATION_CONTEXT # pylint: disable=global-statement
|
||||
AUGMENTATION_CONTEXT = preparation_context
|
||||
|
||||
|
||||
def _augment_sample(timed_sample, context=None):
|
||||
context = AUGMENTATION_CONTEXT if context is None else context
|
||||
sample, clock = timed_sample
|
||||
for augmentation in context.augmentations:
|
||||
if random.random() < augmentation.probability:
|
||||
augmentation.apply(sample, clock)
|
||||
sample.change_audio_type(new_audio_type=context.target_audio_type)
|
||||
return sample
|
||||
|
||||
|
||||
def apply_sample_augmentations(samples,
|
||||
augmentations,
|
||||
audio_type=AUDIO_TYPE_NP,
|
||||
buffering=BUFFER_SIZE,
|
||||
process_ahead=None,
|
||||
clock=0.0,
|
||||
final_clock=None):
|
||||
"""
|
||||
Prepares samples for being used during training.
|
||||
This includes parallel and buffered application of augmentations and a conversion to a specified audio-type.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
samples : Sample enumeration
|
||||
Typically produced by util.sample_collections.samples_from_sources.
|
||||
augmentations : list of augmentation class instances from util.augmentations.*.
|
||||
List of augmentations of which only the signal ones will get applied to the samples.
|
||||
audio_type : str
|
||||
Target audio-type to convert samples to. See util.audio.Sample.__init__ .
|
||||
buffering : int
|
||||
Read-buffer size to use while reading files.
|
||||
process_ahead : int
|
||||
Number of samples to pre-process ahead of time.
|
||||
clock : float
|
||||
Start or fixed clock value between 0.0 and 1.0 for the first or all samples. Has to be <= than clock_to.
|
||||
final_clock : float
|
||||
Final clock value between 0.0 and 1.0 for the last sample. Has to be >= than clock.
|
||||
Requires samples.__len__ attribute.
|
||||
|
||||
Returns
|
||||
-------
|
||||
iterable of util.sample_collections.LabeledSample or util.audio.Sample
|
||||
"""
|
||||
def timed_samples():
|
||||
if final_clock is None:
|
||||
for sample in samples:
|
||||
yield sample, clock
|
||||
else:
|
||||
for sample_index, sample in enumerate(samples):
|
||||
sample_clock = clock + (final_clock - clock) * (sample_index / len(samples))
|
||||
yield sample, sample_clock
|
||||
|
||||
assert 0.0 <= clock <= 1.0
|
||||
if final_clock is not None:
|
||||
assert 0.0 <= final_clock <= 1.0
|
||||
assert clock <= final_clock
|
||||
augmentations = list(filter(lambda aug: isinstance(aug, SampleAugmentation), augmentations))
|
||||
try:
|
||||
for augmentation in augmentations:
|
||||
augmentation.start(buffering=buffering)
|
||||
context = AugmentationContext(audio_type, augmentations)
|
||||
if process_ahead == 0:
|
||||
for timed_sample in timed_samples():
|
||||
yield _augment_sample(timed_sample, context=context)
|
||||
else:
|
||||
with LimitingPool(process_ahead=process_ahead,
|
||||
initializer=_init_augmentation_worker,
|
||||
initargs=(context,)) as pool:
|
||||
yield from pool.imap(_augment_sample, timed_samples())
|
||||
finally:
|
||||
for augmentation in augmentations:
|
||||
augmentation.stop()
|
||||
|
||||
|
||||
def _enqueue_overlay_samples(sample_source, queue, buffering=BUFFER_SIZE):
|
||||
"""
|
||||
As the central distribution point for overlay samples this function is supposed to run in one process only.
|
||||
This ensures that samples are not used twice if not required.
|
||||
It loads the (raw and still compressed) data and provides it to the actual augmentation workers.
|
||||
These are then doing decompression, potential conversion and overlaying in parallel.
|
||||
"""
|
||||
# preventing cyclic import problems
|
||||
from .sample_collections import samples_from_source # pylint: disable=import-outside-toplevel
|
||||
samples = samples_from_source(sample_source, buffering=buffering, labeled=False)
|
||||
while True:
|
||||
for sample in samples:
|
||||
queue.put(sample)
|
||||
|
||||
|
||||
class Overlay(SampleAugmentation):
|
||||
"""See "Overlay augmentation" in TRAINING.rst"""
|
||||
def __init__(self, source, p=1.0, snr=3.0, layers=1):
|
||||
super(Overlay, self).__init__(p)
|
||||
self.source = source
|
||||
self.snr = float_range(snr)
|
||||
self.layers = int_range(layers)
|
||||
self.queue = Queue(max(1, math.floor(self.probability * self.layers[1] * os.cpu_count())))
|
||||
self.current_sample = None
|
||||
self.enqueue_process = None
|
||||
|
||||
def start(self, buffering=BUFFER_SIZE):
|
||||
self.enqueue_process = Process(target=_enqueue_overlay_samples,
|
||||
args=(self.source, self.queue),
|
||||
kwargs={'buffering': buffering})
|
||||
self.enqueue_process.start()
|
||||
|
||||
def apply(self, sample, clock=0.0):
|
||||
sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP)
|
||||
n_layers = pick_value_from_range(self.layers, clock=clock)
|
||||
audio = sample.audio
|
||||
overlay_data = np.zeros_like(audio)
|
||||
for _ in range(n_layers):
|
||||
overlay_offset = 0
|
||||
while overlay_offset < len(audio):
|
||||
if self.current_sample is None:
|
||||
next_overlay_sample = self.queue.get()
|
||||
next_overlay_sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP)
|
||||
self.current_sample = next_overlay_sample.audio
|
||||
n_required = len(audio) - overlay_offset
|
||||
n_current = len(self.current_sample)
|
||||
if n_required >= n_current: # take it completely
|
||||
overlay_data[overlay_offset:overlay_offset + n_current] += self.current_sample
|
||||
overlay_offset += n_current
|
||||
self.current_sample = None
|
||||
else: # take required slice from head and keep tail for next layer or sample
|
||||
overlay_data[overlay_offset:overlay_offset + n_required] += self.current_sample[0:n_required]
|
||||
overlay_offset += n_required
|
||||
self.current_sample = self.current_sample[n_required:]
|
||||
snr_db = pick_value_from_range(self.snr, clock=clock)
|
||||
orig_dbfs = max_dbfs(audio)
|
||||
overlay_gain = orig_dbfs - max_dbfs(overlay_data) - snr_db
|
||||
audio += overlay_data * gain_db_to_ratio(overlay_gain)
|
||||
sample.audio = normalize_audio(audio, dbfs=orig_dbfs)
|
||||
|
||||
def stop(self):
|
||||
if self.enqueue_process is not None:
|
||||
self.enqueue_process.terminate()
|
||||
|
||||
|
||||
class Codec(SampleAugmentation):
|
||||
"""See "Codec augmentation" in TRAINING.rst"""
|
||||
def __init__(self, p=1.0, bitrate=3200):
|
||||
super(Codec, self).__init__(p)
|
||||
self.bitrate = int_range(bitrate)
|
||||
|
||||
def apply(self, sample, clock=0.0):
|
||||
bitrate = pick_value_from_range(self.bitrate, clock=clock)
|
||||
sample.change_audio_type(new_audio_type=AUDIO_TYPE_PCM) # decoding to ensure it has to get encoded again
|
||||
sample.change_audio_type(new_audio_type=AUDIO_TYPE_OPUS, bitrate=bitrate) # will get decoded again downstream
|
||||
|
||||
|
||||
class Reverb(SampleAugmentation):
|
||||
"""See "Reverb augmentation" in TRAINING.rst"""
|
||||
def __init__(self, p=1.0, delay=20.0, decay=10.0):
|
||||
super(Reverb, self).__init__(p)
|
||||
self.delay = float_range(delay)
|
||||
self.decay = float_range(decay)
|
||||
|
||||
def apply(self, sample, clock=0.0):
|
||||
sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP)
|
||||
audio = np.array(sample.audio, dtype=np.float64)
|
||||
orig_dbfs = max_dbfs(audio)
|
||||
delay = pick_value_from_range(self.delay, clock=clock)
|
||||
decay = pick_value_from_range(self.decay, clock=clock)
|
||||
decay = gain_db_to_ratio(-decay)
|
||||
result = np.copy(audio)
|
||||
primes = [17, 19, 23, 29, 31]
|
||||
for delay_prime in primes: # primes to minimize comb filter interference
|
||||
layer = np.copy(audio)
|
||||
n_delay = math.floor(delay * (delay_prime / primes[0]) * sample.audio_format.rate / 1000.0)
|
||||
n_delay = max(16, n_delay) # 16 samples minimum to avoid performance trap and risk of division by zero
|
||||
for w_index in range(0, math.floor(len(audio) / n_delay)):
|
||||
w1 = w_index * n_delay
|
||||
w2 = (w_index + 1) * n_delay
|
||||
width = min(len(audio) - w2, n_delay) # last window could be smaller
|
||||
layer[w2:w2 + width] += decay * layer[w1:w1 + width]
|
||||
result += layer
|
||||
audio = normalize_audio(result, dbfs=orig_dbfs)
|
||||
sample.audio = np.array(audio, dtype=np.float32)
|
||||
|
||||
|
||||
class Resample(SampleAugmentation):
|
||||
"""See "Resample augmentation" in TRAINING.rst"""
|
||||
def __init__(self, p=1.0, rate=8000):
|
||||
super(Resample, self).__init__(p)
|
||||
self.rate = int_range(rate)
|
||||
|
||||
def apply(self, sample, clock=0.0):
|
||||
# late binding librosa and its dependencies
|
||||
from librosa.core import resample # pylint: disable=import-outside-toplevel
|
||||
sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP)
|
||||
rate = pick_value_from_range(self.rate, clock=clock)
|
||||
audio = sample.audio
|
||||
orig_len = len(audio)
|
||||
audio = np.swapaxes(audio, 0, 1)
|
||||
audio = resample(audio, sample.audio_format.rate, rate)
|
||||
audio = resample(audio, rate, sample.audio_format.rate)
|
||||
audio = np.swapaxes(audio, 0, 1)[0:orig_len]
|
||||
sample.audio = audio
|
||||
|
||||
|
||||
class Volume(SampleAugmentation):
|
||||
"""See "Volume augmentation" in TRAINING.rst"""
|
||||
def __init__(self, p=1.0, dbfs=3.0103):
|
||||
super(Volume, self).__init__(p)
|
||||
self.target_dbfs = float_range(dbfs)
|
||||
|
||||
def apply(self, sample, clock=0.0):
|
||||
sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP)
|
||||
target_dbfs = pick_value_from_range(self.target_dbfs, clock=clock)
|
||||
sample.audio = normalize_audio(sample.audio, dbfs=target_dbfs)
|
||||
|
||||
|
||||
class PitchAndTempo(GraphAugmentation):
|
||||
"""See "Pitch and tempo augmentation" in TRAINING.rst"""
|
||||
def __init__(self, p=1.0, tempo=1.2, pitch=(1.075, 1.075, 0.125)):
|
||||
super(PitchAndTempo, self).__init__(p, domain='spectrogram')
|
||||
self.tempo = float_range(tempo)
|
||||
self.pitch = float_range(pitch)
|
||||
|
||||
def apply(self, tensor, clock=0.0):
|
||||
import tensorflow as tf # pylint: disable=import-outside-toplevel
|
||||
original_shape = tf.shape(tensor)
|
||||
pitch = tf_pick_value_from_range(self.pitch, clock=clock)
|
||||
tempo = tf.math.maximum(1.0, tf_pick_value_from_range(self.tempo, clock=clock))
|
||||
new_freq_size = tf.cast(tf.cast(original_shape[2], tf.float32) * pitch, tf.int32)
|
||||
new_time_size = tf.cast(tf.cast(original_shape[1], tf.float32) / tempo, tf.int32)
|
||||
spectrogram_aug = tf.image.resize_bilinear(tf.expand_dims(tensor, -1), [new_time_size, new_freq_size])
|
||||
spectrogram_aug = tf.image.crop_to_bounding_box(spectrogram_aug,
|
||||
offset_height=0,
|
||||
offset_width=0,
|
||||
target_height=tf.shape(spectrogram_aug)[1],
|
||||
target_width=tf.math.minimum(original_shape[2], new_freq_size))
|
||||
spectrogram_aug = tf.cond(pitch < 1,
|
||||
lambda: tf.image.pad_to_bounding_box(spectrogram_aug,
|
||||
offset_height=0,
|
||||
offset_width=0,
|
||||
target_height=tf.shape(spectrogram_aug)[1],
|
||||
target_width=original_shape[2]),
|
||||
lambda: spectrogram_aug)
|
||||
return spectrogram_aug[:, :, :, 0]
|
||||
|
||||
|
||||
class Speed(GraphAugmentation):
|
||||
"""See "Speed augmentation" in TRAINING.rst"""
|
||||
def __init__(self, p=1.0, factor=1.1):
|
||||
super(Speed, self).__init__(p, domain='spectrogram')
|
||||
self.factor = float_range(factor)
|
||||
|
||||
def apply(self, tensor, clock=0.0):
|
||||
import tensorflow as tf # pylint: disable=import-outside-toplevel
|
||||
factor = tf_pick_value_from_range(self.factor, clock=clock)
|
||||
original_shape = tf.shape(tensor)
|
||||
new_time_size = tf.cast(tf.cast(original_shape[1], tf.float32) / factor, tf.int32)
|
||||
spectrogram_aug = tf.image.resize_bilinear(tf.expand_dims(tensor, -1), [new_time_size, original_shape[2]])
|
||||
return spectrogram_aug[:, :, :, 0]
|
||||
|
||||
|
||||
class Warp(GraphAugmentation):
|
||||
"""See "Warp augmentation" in TRAINING.rst"""
|
||||
def __init__(self, p=1.0, shift=100.0, order=3, nbp=1, ncp=1, regularization_weight=0.0):
|
||||
super(Warp, self).__init__(p, domain='spectrogram')
|
||||
self.shift = float_range(shift)
|
||||
self.order = int_range(order)
|
||||
self.nbp = int_range(nbp)
|
||||
self.ncp = int_range(ncp)
|
||||
# Making this a value-range is impossible, as it would get a tensor which would downstream be used as parameter
|
||||
# of a comparison inside tensorflow.contrib.image.python.ops.interpolate_spline. This is not supported.
|
||||
self.regularization_weight = float(regularization_weight)
|
||||
|
||||
def apply(self, tensor, clock=0.0):
|
||||
import tensorflow as tf # pylint: disable=import-outside-toplevel
|
||||
from .flags import FLAGS # pylint: disable=import-outside-toplevel
|
||||
from .sparse_image_warp import sparse_image_warp # pylint: disable=import-outside-toplevel
|
||||
|
||||
# reshape to fit `sparse_image_warp`'s input shape (1, time steps, freq, 1), batch_size must be 1
|
||||
expanded_spectrogram = tf.expand_dims(tensor, -1)
|
||||
original_shape = tf.shape(expanded_spectrogram)
|
||||
tau, freq_size = original_shape[1], original_shape[2]
|
||||
seed = (clock * tf.int32.min, clock * tf.int32.max)
|
||||
|
||||
shift = tf_pick_value_from_range(self.shift, clock=clock)
|
||||
shift *= FLAGS.audio_sample_rate / (FLAGS.feature_win_step * 1000.0) # number of windows
|
||||
shift = tf.math.minimum(tf.cast(shift, dtype=tf.int32), tf.math.floordiv(tau, 2) - 1) # to protect short audio
|
||||
nbp = tf_pick_value_from_range(self.nbp, clock=clock)
|
||||
ncp = tf_pick_value_from_range(self.ncp, clock=clock)
|
||||
# workaround for missing stateless shuffle support
|
||||
frequencies = tf.random.stateless_uniform([2 * ncp], seed, minval=1, maxval=freq_size - 2, dtype=tf.int32)
|
||||
frequencies = tf.unique(tf.concat([frequencies, tf.range(1, limit=freq_size - 3)], axis=0))[0][0:ncp]
|
||||
source_max = tau - shift
|
||||
source_min = tf.math.minimum(source_max - ncp, shift)
|
||||
# workaround for missing stateless shuffle support
|
||||
src_times = tf.random.stateless_uniform([2 * ncp], seed, minval=source_min, maxval=source_max, dtype=tf.int32)
|
||||
src_times = tf.unique(tf.concat([src_times, tf.range(1, limit=source_max)], axis=0))[0][0:ncp]
|
||||
dst_times = src_times + tf.random.stateless_uniform([ncp], seed, minval=-shift, maxval=shift, dtype=tf.int32)
|
||||
scp_locations = tf.cast([tf.transpose(tf.stack([src_times, frequencies]))], dtype=tf.float32)
|
||||
dcp_locations = tf.cast([tf.transpose(tf.stack([dst_times, frequencies]))], dtype=tf.float32)
|
||||
|
||||
order = tf_pick_value_from_range(self.order, clock=clock)
|
||||
order = tf.math.maximum(3, order) # prevents "Input matrix is not invertible." exception
|
||||
order = tf.cast(order, tf.float32)
|
||||
|
||||
spectrogram_aug, _ = sparse_image_warp(expanded_spectrogram,
|
||||
source_control_point_locations=scp_locations,
|
||||
dest_control_point_locations=dcp_locations,
|
||||
interpolation_order=order,
|
||||
regularization_weight=self.regularization_weight,
|
||||
num_boundary_points=nbp)
|
||||
return tf.reshape(spectrogram_aug, shape=(1, -1, freq_size))
|
||||
|
||||
|
||||
class FrequencyMask(GraphAugmentation):
|
||||
"""See "Frequency mask augmentation" in TRAINING.rst"""
|
||||
def __init__(self, p=1.0, n=3, size=2):
|
||||
super(FrequencyMask, self).__init__(p, domain='spectrogram')
|
||||
self.n = int_range(n) # pylint: disable=invalid-name
|
||||
self.size = int_range(size)
|
||||
|
||||
def apply(self, tensor, clock=0.0):
|
||||
import tensorflow as tf # pylint: disable=import-outside-toplevel
|
||||
time_max = tf.shape(tensor)[1]
|
||||
freq_max = tf.shape(tensor)[2]
|
||||
n = tf_pick_value_from_range(self.n, clock=clock)
|
||||
|
||||
def body(i, spectrogram_aug):
|
||||
size = tf_pick_value_from_range(self.size, clock=clock)
|
||||
size = tf.math.maximum(1, tf.math.minimum(freq_max - 1, size))
|
||||
seed = tf.cast(clock * tf.int32.max, tf.int32) - i
|
||||
f0 = tf.random.stateless_uniform((), (-seed, seed), minval=0, maxval=freq_max - size, dtype=tf.dtypes.int32)
|
||||
freq_mask = tf.concat([tf.ones([1, time_max, f0]),
|
||||
tf.zeros([1, time_max, size]),
|
||||
tf.ones([1, time_max, freq_max - f0 - size])], axis=2)
|
||||
return i + 1, spectrogram_aug * freq_mask
|
||||
|
||||
return tf.while_loop(lambda i, spectrogram_aug: i < n, body, (0, tensor))[1]
|
||||
|
||||
|
||||
class TimeMask(GraphAugmentation):
|
||||
"""See "Time mask augmentation" in TRAINING.rst"""
|
||||
def __init__(self, p=1.0, domain='spectrogram', n=3, size=10.0):
|
||||
super(TimeMask, self).__init__(p, domain=domain)
|
||||
self.n = int_range(n) # pylint: disable=invalid-name
|
||||
self.size = float_range(size)
|
||||
|
||||
def apply(self, tensor, clock=0.0):
|
||||
import tensorflow as tf # pylint: disable=import-outside-toplevel
|
||||
from .flags import FLAGS # pylint: disable=import-outside-toplevel
|
||||
time_factor = FLAGS.audio_sample_rate / 1000.0 # samples per ms
|
||||
if self.domain != 'signal':
|
||||
time_factor /= FLAGS.feature_win_step # windows per ms
|
||||
time_max = tf.shape(tensor)[0] if self.domain == 'signal' else tf.shape(tensor)[1]
|
||||
n = tf_pick_value_from_range(self.n, clock=clock)
|
||||
|
||||
def body(i, augmented):
|
||||
size = tf.cast(tf_pick_value_from_range(self.size, clock=clock) * time_factor, dtype=tf.int32)
|
||||
size = tf.math.maximum(1, tf.math.minimum(time_max - 1, size))
|
||||
tf.print(size)
|
||||
seed = tf.cast(clock * tf.int32.max, tf.int32) - i
|
||||
t0 = tf.random.stateless_uniform((), (-seed, seed), minval=0, maxval=time_max - size, dtype=tf.dtypes.int32)
|
||||
rest = time_max - t0 - size
|
||||
if self.domain == 'spectrogram':
|
||||
fm = tf.shape(tensor)[2]
|
||||
time_mask = tf.concat([tf.ones([1, t0, fm]), tf.zeros([1, size, fm]), tf.ones([1, rest, fm])], axis=1)
|
||||
elif self.domain == 'signal':
|
||||
time_mask = tf.concat([tf.ones([t0, 1]), tf.zeros([size, 1]), tf.ones([rest, 1])], axis=0)
|
||||
else:
|
||||
time_mask = tf.concat([tf.ones([1, t0]), tf.zeros([1, size]), tf.ones([1, rest])], axis=1)
|
||||
return i + 1, augmented * time_mask
|
||||
|
||||
return tf.while_loop(lambda i, augmented: i < n, body, (0, tensor))[1]
|
||||
|
||||
|
||||
class Dropout(GraphAugmentation):
|
||||
"""See "Dropout augmentation" in TRAINING.rst"""
|
||||
def __init__(self, p=1.0, domain='spectrogram', rate=0.05):
|
||||
super(Dropout, self).__init__(p, domain=domain)
|
||||
self.rate = float_range(rate)
|
||||
|
||||
def apply(self, tensor, clock=0.0):
|
||||
import tensorflow as tf # pylint: disable=import-outside-toplevel
|
||||
rate = tf_pick_value_from_range(self.rate, clock=clock)
|
||||
rate = tf.math.maximum(0.0, rate)
|
||||
factors = tf.random.stateless_uniform(tf.shape(tensor),
|
||||
(clock * tf.int32.min, clock * tf.int32.max),
|
||||
minval=0.0,
|
||||
maxval=1.0,
|
||||
dtype=tf.float32)
|
||||
return tensor * tf.math.sign(tf.math.floor(factors + rate))
|
||||
|
||||
|
||||
class Add(GraphAugmentation):
|
||||
"""See "Add augmentation" in TRAINING.rst"""
|
||||
def __init__(self, p=1.0, domain='features', stddev=5):
|
||||
super(Add, self).__init__(p, domain=domain)
|
||||
self.stddev = float_range(stddev)
|
||||
|
||||
def apply(self, tensor, clock=0.0):
|
||||
import tensorflow as tf # pylint: disable=import-outside-toplevel
|
||||
stddev = tf_pick_value_from_range(self.stddev, clock=clock)
|
||||
seed = (clock * tf.int32.min, clock * tf.int32.max)
|
||||
return tensor + tf.random.stateless_normal(tf.shape(tensor), seed, mean=0.0, stddev=stddev)
|
||||
|
||||
|
||||
class Multiply(GraphAugmentation):
|
||||
"""See "Multiply augmentation" in TRAINING.rst"""
|
||||
def __init__(self, p=1.0, domain='features', stddev=5):
|
||||
super(Multiply, self).__init__(p, domain=domain)
|
||||
self.stddev = float_range(stddev)
|
||||
|
||||
def apply(self, tensor, clock=0.0):
|
||||
import tensorflow as tf # pylint: disable=import-outside-toplevel
|
||||
stddev = tf_pick_value_from_range(self.stddev, clock=clock)
|
||||
seed = (clock * tf.int32.min, clock * tf.int32.max)
|
||||
return tensor * tf.random.stateless_normal(tf.shape(tensor), seed, mean=1.0, stddev=stddev)
|
|
@ -2,7 +2,6 @@ from __future__ import absolute_import, division, print_function
|
|||
|
||||
import os
|
||||
import sys
|
||||
import tensorflow as tf
|
||||
import tensorflow.compat.v1 as tfv1
|
||||
|
||||
from attrdict import AttrDict
|
||||
|
@ -13,6 +12,7 @@ from .gpu import get_available_gpus
|
|||
from .logging import log_error, log_warn
|
||||
from .text import Alphabet, UTF8Alphabet
|
||||
from .helpers import parse_file_size
|
||||
from .augmentations import parse_augmentations
|
||||
|
||||
class ConfigSingleton:
|
||||
_config = None
|
||||
|
@ -30,6 +30,17 @@ Config = ConfigSingleton() # pylint: disable=invalid-name
|
|||
def initialize_globals():
|
||||
c = AttrDict()
|
||||
|
||||
# Augmentations
|
||||
c.augmentations = parse_augmentations(FLAGS.augment)
|
||||
if len(c.augmentations) > 0 and FLAGS.feature_cache is not None and FLAGS.cache_for_epochs == 0:
|
||||
log_warn('Due to current feature-cache settings the exact same sample augmentations of the first '
|
||||
'epoch will be repeated on all following epochs. This could lead to unintended over-fitting. '
|
||||
'You could use --cache_for_epochs <n_epochs> to invalidate the cache after a given number of epochs.')
|
||||
|
||||
# Caching
|
||||
if FLAGS.cache_for_epochs == 1:
|
||||
log_warn('--cache_for_epochs == 1 is (re-)creating the feature cache on every epoch but will never use it.')
|
||||
|
||||
# Read-buffer
|
||||
FLAGS.read_buffer = parse_file_size(FLAGS.read_buffer)
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
from collections import Counter
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
|
@ -11,13 +12,13 @@ from tensorflow.python.ops import gen_audio_ops as contrib_audio
|
|||
from .config import Config
|
||||
from .text import text_to_char_array
|
||||
from .flags import FLAGS
|
||||
from .spectrogram_augmentations import augment_freq_time_mask, augment_dropout, augment_pitch_and_tempo, augment_speed_up, augment_sparse_warp
|
||||
from .augmentations import apply_sample_augmentations, apply_graph_augmentations
|
||||
from .audio import read_frames_from_file, vad_split, pcm_to_np, DEFAULT_FORMAT
|
||||
from .sample_collections import samples_from_sources, augment_samples
|
||||
from .sample_collections import samples_from_sources
|
||||
from .helpers import remember_exception, MEGABYTE
|
||||
|
||||
|
||||
def samples_to_mfccs(samples, sample_rate, train_phase=False, sample_id=None):
|
||||
def audio_to_features(audio, sample_rate, clock=0.0, train_phase=False, augmentations=None, sample_id=None):
|
||||
if train_phase:
|
||||
# We need the lambdas to make TensorFlow happy.
|
||||
# pylint: disable=unnecessary-lambda
|
||||
|
@ -27,73 +28,48 @@ def samples_to_mfccs(samples, sample_rate, train_phase=False, sample_id=None):
|
|||
lambda: tf.no_op(),
|
||||
name='matching_sample_rate')
|
||||
|
||||
spectrogram = contrib_audio.audio_spectrogram(samples,
|
||||
if train_phase and augmentations is not None:
|
||||
audio = apply_graph_augmentations('signal', audio, augmentations, clock=clock)
|
||||
|
||||
spectrogram = contrib_audio.audio_spectrogram(audio,
|
||||
window_size=Config.audio_window_samples,
|
||||
stride=Config.audio_step_samples,
|
||||
magnitude_squared=True)
|
||||
|
||||
# Data Augmentations
|
||||
if train_phase:
|
||||
if FLAGS.augmentation_spec_dropout_keeprate < 1:
|
||||
spectrogram = augment_dropout(spectrogram,
|
||||
keep_prob=FLAGS.augmentation_spec_dropout_keeprate)
|
||||
if train_phase and augmentations is not None:
|
||||
spectrogram = apply_graph_augmentations('spectrogram', spectrogram, augmentations, clock=clock)
|
||||
|
||||
# sparse warp must before freq/time masking
|
||||
if FLAGS.augmentation_sparse_warp:
|
||||
spectrogram = augment_sparse_warp(spectrogram,
|
||||
time_warping_para=FLAGS.augmentation_sparse_warp_time_warping_para,
|
||||
interpolation_order=FLAGS.augmentation_sparse_warp_interpolation_order,
|
||||
regularization_weight=FLAGS.augmentation_sparse_warp_regularization_weight,
|
||||
num_boundary_points=FLAGS.augmentation_sparse_warp_num_boundary_points,
|
||||
num_control_points=FLAGS.augmentation_sparse_warp_num_control_points)
|
||||
features = contrib_audio.mfcc(spectrogram=spectrogram,
|
||||
sample_rate=sample_rate,
|
||||
dct_coefficient_count=Config.n_input,
|
||||
upper_frequency_limit=FLAGS.audio_sample_rate / 2)
|
||||
features = tf.reshape(features, [-1, Config.n_input])
|
||||
|
||||
if FLAGS.augmentation_freq_and_time_masking:
|
||||
spectrogram = augment_freq_time_mask(spectrogram,
|
||||
frequency_masking_para=FLAGS.augmentation_freq_and_time_masking_freq_mask_range,
|
||||
time_masking_para=FLAGS.augmentation_freq_and_time_masking_time_mask_range,
|
||||
frequency_mask_num=FLAGS.augmentation_freq_and_time_masking_number_freq_masks,
|
||||
time_mask_num=FLAGS.augmentation_freq_and_time_masking_number_time_masks)
|
||||
if train_phase and augmentations is not None:
|
||||
features = apply_graph_augmentations('features', features, augmentations, clock=clock)
|
||||
|
||||
if FLAGS.augmentation_pitch_and_tempo_scaling:
|
||||
spectrogram = augment_pitch_and_tempo(spectrogram,
|
||||
max_tempo=FLAGS.augmentation_pitch_and_tempo_scaling_max_tempo,
|
||||
max_pitch=FLAGS.augmentation_pitch_and_tempo_scaling_max_pitch,
|
||||
min_pitch=FLAGS.augmentation_pitch_and_tempo_scaling_min_pitch)
|
||||
|
||||
if FLAGS.augmentation_speed_up_std > 0:
|
||||
spectrogram = augment_speed_up(spectrogram, speed_std=FLAGS.augmentation_speed_up_std)
|
||||
|
||||
mfccs = contrib_audio.mfcc(spectrogram=spectrogram,
|
||||
sample_rate=sample_rate,
|
||||
dct_coefficient_count=Config.n_input,
|
||||
upper_frequency_limit=FLAGS.audio_sample_rate/2)
|
||||
mfccs = tf.reshape(mfccs, [-1, Config.n_input])
|
||||
|
||||
return mfccs, tf.shape(input=mfccs)[0]
|
||||
return features, tf.shape(input=features)[0]
|
||||
|
||||
|
||||
def audio_to_features(audio, sample_rate, train_phase=False, sample_id=None):
|
||||
features, features_len = samples_to_mfccs(audio, sample_rate, train_phase=train_phase, sample_id=sample_id)
|
||||
|
||||
if train_phase:
|
||||
if FLAGS.data_aug_features_multiplicative > 0:
|
||||
features = features*tf.random.normal(mean=1, stddev=FLAGS.data_aug_features_multiplicative, shape=tf.shape(features))
|
||||
|
||||
if FLAGS.data_aug_features_additive > 0:
|
||||
features = features+tf.random.normal(mean=0.0, stddev=FLAGS.data_aug_features_additive, shape=tf.shape(features))
|
||||
|
||||
return features, features_len
|
||||
|
||||
|
||||
def audiofile_to_features(wav_filename, train_phase=False):
|
||||
def audiofile_to_features(wav_filename, clock=0.0, train_phase=False, augmentations=None):
|
||||
samples = tf.io.read_file(wav_filename)
|
||||
decoded = contrib_audio.decode_wav(samples, desired_channels=1)
|
||||
return audio_to_features(decoded.audio, decoded.sample_rate, train_phase=train_phase, sample_id=wav_filename)
|
||||
return audio_to_features(decoded.audio,
|
||||
decoded.sample_rate,
|
||||
clock=clock,
|
||||
train_phase=train_phase,
|
||||
augmentations=augmentations,
|
||||
sample_id=wav_filename)
|
||||
|
||||
|
||||
def entry_to_features(sample_id, audio, sample_rate, transcript, train_phase=False):
|
||||
def entry_to_features(sample_id, audio, sample_rate, transcript, clock, train_phase=False, augmentations=None):
|
||||
# https://bugs.python.org/issue32117
|
||||
features, features_len = audio_to_features(audio, sample_rate, train_phase=train_phase, sample_id=sample_id)
|
||||
features, features_len = audio_to_features(audio,
|
||||
sample_rate,
|
||||
clock=clock,
|
||||
train_phase=train_phase,
|
||||
augmentations=augmentations,
|
||||
sample_id=sample_id)
|
||||
sparse_transcript = tf.SparseTensor(*transcript)
|
||||
return sample_id, features, features_len, sparse_transcript
|
||||
|
||||
|
@ -109,25 +85,32 @@ def to_sparse_tuple(sequence):
|
|||
|
||||
def create_dataset(sources,
|
||||
batch_size,
|
||||
repetitions=1,
|
||||
augmentation_specs=None,
|
||||
enable_cache=False,
|
||||
epochs=1,
|
||||
augmentations=None,
|
||||
cache_path=None,
|
||||
train_phase=False,
|
||||
exception_box=None,
|
||||
process_ahead=None,
|
||||
buffering=1 * MEGABYTE):
|
||||
epoch_counter = Counter() # survives restarts of the dataset and its generator
|
||||
|
||||
def generate_values():
|
||||
epoch = epoch_counter['epoch']
|
||||
if train_phase:
|
||||
epoch_counter['epoch'] += 1
|
||||
samples = samples_from_sources(sources, buffering=buffering, labeled=True)
|
||||
samples = augment_samples(samples,
|
||||
repetitions=repetitions,
|
||||
augmentation_specs=augmentation_specs,
|
||||
buffering=buffering,
|
||||
process_ahead=2 * batch_size if process_ahead is None else process_ahead)
|
||||
for sample in samples:
|
||||
num_samples = len(samples)
|
||||
samples = apply_sample_augmentations(samples,
|
||||
augmentations,
|
||||
buffering=buffering,
|
||||
process_ahead=2 * batch_size if process_ahead is None else process_ahead,
|
||||
clock=epoch / epochs,
|
||||
final_clock=(epoch + 1) / epochs)
|
||||
for sample_index, sample in enumerate(samples):
|
||||
clock = (epoch * num_samples + sample_index) / (epochs * num_samples) if train_phase and epochs > 0 else 0.0
|
||||
transcript = text_to_char_array(sample.transcript, Config.alphabet, context=sample.sample_id)
|
||||
transcript = to_sparse_tuple(transcript)
|
||||
yield sample.sample_id, sample.audio, sample.audio_format.rate, transcript
|
||||
yield sample.sample_id, sample.audio, sample.audio_format.rate, transcript, clock
|
||||
|
||||
# Batching a dataset of 2D SparseTensors creates 3D batches, which fail
|
||||
# when passed to tf.nn.ctc_loss, so we reshape them to remove the extra
|
||||
|
@ -143,13 +126,13 @@ def create_dataset(sources,
|
|||
sample_ids = sample_ids.batch(batch_size)
|
||||
return tf.data.Dataset.zip((sample_ids, features, transcripts))
|
||||
|
||||
process_fn = partial(entry_to_features, train_phase=train_phase)
|
||||
process_fn = partial(entry_to_features, train_phase=train_phase, augmentations=augmentations)
|
||||
|
||||
dataset = (tf.data.Dataset.from_generator(remember_exception(generate_values, exception_box),
|
||||
output_types=(tf.string, tf.float32, tf.int32,
|
||||
(tf.int64, tf.int32, tf.int64)))
|
||||
(tf.int64, tf.int32, tf.int64), tf.float64))
|
||||
.map(process_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE))
|
||||
if enable_cache:
|
||||
if cache_path is not None:
|
||||
dataset = dataset.cache(cache_path)
|
||||
dataset = (dataset.window(batch_size, drop_remainder=train_phase).flat_map(batch_fn)
|
||||
.prefetch(len(Config.available_devices)))
|
||||
|
@ -172,7 +155,7 @@ def split_audio_file(audio_path,
|
|||
yield time_start, time_end, samples
|
||||
|
||||
def to_mfccs(time_start, time_end, samples):
|
||||
features, features_len = samples_to_mfccs(samples, audio_format.rate)
|
||||
features, features_len = audio_to_features(samples, audio_format.rate)
|
||||
return time_start, time_end, features, features_len
|
||||
|
||||
def create_batch_set(bs, criteria):
|
||||
|
|
|
@ -19,6 +19,7 @@ def create_flags():
|
|||
|
||||
f.DEFINE_string('read_buffer', '1MB', 'buffer-size for reading samples from datasets (supports file-size suffixes KB, MB, GB, TB)')
|
||||
f.DEFINE_string('feature_cache', '', 'cache MFCC features to disk to speed up future training runs on the same data. This flag specifies the path where cached features extracted from --train_files will be saved. If empty, or if online augmentation flags are enabled, caching will be disabled.')
|
||||
f.DEFINE_integer('cache_for_epochs', 0, 'after how many epochs the feature cache is invalidated again - 0 for "never"')
|
||||
|
||||
f.DEFINE_integer('feature_win_len', 32, 'feature extraction audio window length in milliseconds')
|
||||
f.DEFINE_integer('feature_win_step', 20, 'feature extraction window step length in milliseconds')
|
||||
|
@ -28,32 +29,6 @@ def create_flags():
|
|||
# ================
|
||||
|
||||
f.DEFINE_multi_string('augment', None, 'specifies an augmentation of the training samples. Format is "--augment operation[param1=value1, ...]"')
|
||||
f.DEFINE_integer('augmentations_per_epoch', 1, 'how often the train set should be repeated and re-augmented per epoch')
|
||||
|
||||
f.DEFINE_float('data_aug_features_additive', 0, 'std of the Gaussian additive noise')
|
||||
f.DEFINE_float('data_aug_features_multiplicative', 0, 'std of normal distribution around 1 for multiplicative noise')
|
||||
|
||||
f.DEFINE_float('augmentation_spec_dropout_keeprate', 1, 'keep rate of dropout augmentation on spectrogram (if 1, no dropout will be performed on spectrogram)')
|
||||
|
||||
f.DEFINE_boolean('augmentation_sparse_warp', False, 'whether to use spectrogram sparse warp. USE OF THIS FLAG IS UNSUPPORTED, enable sparse warp will increase training time drastically, and the paper also mentioned that this is not a major factor to improve accuracy.')
|
||||
f.DEFINE_integer('augmentation_sparse_warp_num_control_points', 1, 'specify number of control points')
|
||||
f.DEFINE_integer('augmentation_sparse_warp_time_warping_para', 20, 'time_warping_para')
|
||||
f.DEFINE_integer('augmentation_sparse_warp_interpolation_order', 2, 'sparse_warp_interpolation_order')
|
||||
f.DEFINE_float('augmentation_sparse_warp_regularization_weight', 0.0, 'sparse_warp_regularization_weight')
|
||||
f.DEFINE_integer('augmentation_sparse_warp_num_boundary_points', 1, 'sparse_warp_num_boundary_points')
|
||||
|
||||
f.DEFINE_boolean('augmentation_freq_and_time_masking', False, 'whether to use frequency and time masking augmentation')
|
||||
f.DEFINE_integer('augmentation_freq_and_time_masking_freq_mask_range', 5, 'max range of masks in the frequency domain when performing freqtime-mask augmentation')
|
||||
f.DEFINE_integer('augmentation_freq_and_time_masking_number_freq_masks', 3, 'number of masks in the frequency domain when performing freqtime-mask augmentation')
|
||||
f.DEFINE_integer('augmentation_freq_and_time_masking_time_mask_range', 2, 'max range of masks in the time domain when performing freqtime-mask augmentation')
|
||||
f.DEFINE_integer('augmentation_freq_and_time_masking_number_time_masks', 3, 'number of masks in the time domain when performing freqtime-mask augmentation')
|
||||
|
||||
f.DEFINE_float('augmentation_speed_up_std', 0, 'std for speeding-up tempo. If std is 0, this augmentation is not performed')
|
||||
|
||||
f.DEFINE_boolean('augmentation_pitch_and_tempo_scaling', False, 'whether to use spectrogram speed and tempo scaling')
|
||||
f.DEFINE_float('augmentation_pitch_and_tempo_scaling_min_pitch', 0.95, 'min value of pitch scaling')
|
||||
f.DEFINE_float('augmentation_pitch_and_tempo_scaling_max_pitch', 1.2, 'max value of pitch scaling')
|
||||
f.DEFINE_float('augmentation_pitch_and_tempo_scaling_max_tempo', 1.2, 'max vlaue of tempo scaling')
|
||||
|
||||
# Global Constants
|
||||
# ================
|
||||
|
|
|
@ -174,3 +174,18 @@ def pick_value_from_range(value_range, clock=None):
|
|||
value = value_range.start + clock * (value_range.end - value_range.start)
|
||||
value = random.uniform(value - value_range.r, value + value_range.r)
|
||||
return round(value) if isinstance(value_range.start, int) else value
|
||||
|
||||
|
||||
def tf_pick_value_from_range(value_range, clock=None, double_precision=False):
|
||||
import tensorflow as tf # pylint: disable=import-outside-toplevel
|
||||
clock = (tf.random.stateless_uniform([], seed=(-1, 1), dtype=tf.float64) if clock is None
|
||||
else tf.maximum(tf.constant(0.0, dtype=tf.float64), tf.minimum(tf.constant(1.0, dtype=tf.float64), clock)))
|
||||
value = value_range.start + clock * (value_range.end - value_range.start)
|
||||
value = tf.random.stateless_uniform([],
|
||||
minval=value - value_range.r,
|
||||
maxval=value + value_range.r,
|
||||
seed=(clock * tf.int32.min, clock * tf.int32.max),
|
||||
dtype=tf.float64)
|
||||
if isinstance(value_range.start, int):
|
||||
return tf.cast(tf.math.round(value), tf.int64 if double_precision else tf.int32)
|
||||
return tf.cast(value, tf.float64 if double_precision else tf.float32)
|
||||
|
|
|
@ -2,14 +2,12 @@
|
|||
import os
|
||||
import csv
|
||||
import json
|
||||
import random
|
||||
|
||||
from pathlib import Path
|
||||
from functools import partial
|
||||
|
||||
from .signal_augmentations import parse_augmentation
|
||||
from .helpers import MEGABYTE, GIGABYTE, Interleaved, LimitingPool
|
||||
from .audio import Sample, DEFAULT_FORMAT, AUDIO_TYPE_OPUS, AUDIO_TYPE_NP, SERIALIZABLE_AUDIO_TYPES, get_audio_type_from_extension
|
||||
from .helpers import MEGABYTE, GIGABYTE, Interleaved
|
||||
from .audio import Sample, DEFAULT_FORMAT, AUDIO_TYPE_OPUS, SERIALIZABLE_AUDIO_TYPES, get_audio_type_from_extension
|
||||
|
||||
BIG_ENDIAN = 'big'
|
||||
INT_SIZE = 4
|
||||
|
@ -416,88 +414,3 @@ def samples_from_sources(sample_sources, buffering=BUFFER_SIZE, labeled=None):
|
|||
return samples_from_source(sample_sources[0], buffering=buffering, labeled=labeled)
|
||||
cols = list(map(partial(samples_from_source, buffering=buffering, labeled=labeled), sample_sources))
|
||||
return Interleaved(*cols, key=lambda s: s.duration)
|
||||
|
||||
|
||||
class PreparationContext:
|
||||
def __init__(self, target_audio_type, augmentations):
|
||||
self.target_audio_type = target_audio_type
|
||||
self.augmentations = augmentations
|
||||
|
||||
|
||||
AUGMENTATION_CONTEXT = None
|
||||
|
||||
|
||||
def _init_augmentation_worker(preparation_context):
|
||||
global AUGMENTATION_CONTEXT # pylint: disable=global-statement
|
||||
AUGMENTATION_CONTEXT = preparation_context
|
||||
|
||||
|
||||
def _augment_sample(timed_sample, context=None):
|
||||
context = AUGMENTATION_CONTEXT if context is None else context
|
||||
sample, clock = timed_sample
|
||||
for augmentation in context.augmentations:
|
||||
if random.random() < augmentation.probability:
|
||||
augmentation.apply(sample, clock)
|
||||
sample.change_audio_type(new_audio_type=context.target_audio_type)
|
||||
return sample
|
||||
|
||||
|
||||
def augment_samples(samples,
|
||||
audio_type=AUDIO_TYPE_NP,
|
||||
augmentation_specs=None,
|
||||
buffering=BUFFER_SIZE,
|
||||
process_ahead=None,
|
||||
repetitions=1,
|
||||
fixed_clock=None):
|
||||
"""
|
||||
Prepares samples for being used during training.
|
||||
This includes parallel and buffered application of augmentations and a conversion to a specified audio-type.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
samples : Sample enumeration
|
||||
Typically produced by samples_from_sources.
|
||||
audio_type : str
|
||||
Target audio-type to convert samples to. See util.audio.Sample.__init__ .
|
||||
augmentation_specs : list of str
|
||||
Augmentation specifications like ["reverb[delay=20.0,decay=-20]", "volume"]. See TRAINING.rst.
|
||||
buffering : int
|
||||
Read-buffer size to use while reading files.
|
||||
process_ahead : int
|
||||
Number of samples to pre-process ahead of time.
|
||||
repetitions : int
|
||||
How often the input sample enumeration should get repeated for being re-augmented.
|
||||
fixed_clock : float
|
||||
Sets the internal clock to a value between 0.0 (beginning of epoch) and 1.0 (end of epoch).
|
||||
Setting this to a number is used for simulating augmentations at a certain epoch-time.
|
||||
If kept at None (default), the internal clock will run regularly from 0.0 to 1.0,
|
||||
hence preparing them for training.
|
||||
|
||||
Returns
|
||||
-------
|
||||
iterable of util.sample_collections.LabeledSample or util.audio.Sample
|
||||
"""
|
||||
def timed_samples():
|
||||
for repetition in range(repetitions):
|
||||
for sample_index, sample in enumerate(samples):
|
||||
if fixed_clock is None:
|
||||
yield sample, (repetition * len(samples) + sample_index) / (repetitions * len(samples))
|
||||
else:
|
||||
yield sample, fixed_clock
|
||||
|
||||
augmentations = [] if augmentation_specs is None else list(map(parse_augmentation, augmentation_specs))
|
||||
try:
|
||||
for augmentation in augmentations:
|
||||
augmentation.start(buffering=buffering)
|
||||
context = PreparationContext(audio_type, augmentations)
|
||||
if process_ahead == 0:
|
||||
for timed_sample in timed_samples():
|
||||
yield _augment_sample(timed_sample, context=context)
|
||||
else:
|
||||
with LimitingPool(process_ahead=process_ahead,
|
||||
initializer=_init_augmentation_worker,
|
||||
initargs=(context,)) as pool:
|
||||
yield from pool.imap(_augment_sample, timed_samples())
|
||||
finally:
|
||||
for augmentation in augmentations:
|
||||
augmentation.stop()
|
||||
|
|
|
@ -1,222 +0,0 @@
|
|||
|
||||
import os
|
||||
import re
|
||||
import math
|
||||
import random
|
||||
import numpy as np
|
||||
|
||||
from multiprocessing import Queue, Process
|
||||
from .audio import gain_db_to_ratio, max_dbfs, normalize_audio, AUDIO_TYPE_NP, AUDIO_TYPE_PCM, AUDIO_TYPE_OPUS
|
||||
from .helpers import int_range, float_range, pick_value_from_range, MEGABYTE
|
||||
|
||||
SPEC_PARSER = re.compile(r'^(?P<cls>[a-z]+)(\[(?P<params>.*)\])?$')
|
||||
BUFFER_SIZE = 1 * MEGABYTE
|
||||
|
||||
|
||||
class Augmentation:
|
||||
def __init__(self, p=1.0):
|
||||
self.probability = float(p)
|
||||
|
||||
def start(self, buffering=BUFFER_SIZE):
|
||||
pass
|
||||
|
||||
def apply(self, sample, clock):
|
||||
raise NotImplementedError
|
||||
|
||||
def stop(self):
|
||||
pass
|
||||
|
||||
|
||||
def _enqueue_overlay_samples(sample_source, queue, buffering=BUFFER_SIZE):
|
||||
"""
|
||||
As the central distribution point for overlay samples this function is supposed to run in one process only.
|
||||
This ensures that samples are not used twice if not required.
|
||||
It loads the (raw and still compressed) data and provides it to the actual augmentation workers.
|
||||
These are then doing decompression, potential conversion and overlaying in parallel.
|
||||
"""
|
||||
# preventing cyclic import problems
|
||||
from .sample_collections import samples_from_source # pylint: disable=import-outside-toplevel
|
||||
samples = samples_from_source(sample_source, buffering=buffering, labeled=False)
|
||||
while True:
|
||||
for sample in samples:
|
||||
queue.put(sample)
|
||||
|
||||
|
||||
class Overlay(Augmentation):
|
||||
"""See "Overlay augmentation" in TRAINING.rst"""
|
||||
def __init__(self, source, p=1.0, snr=3.0, layers=1):
|
||||
super(Overlay, self).__init__(p)
|
||||
self.source = source
|
||||
self.snr = float_range(snr)
|
||||
self.layers = int_range(layers)
|
||||
self.queue = Queue(max(1, math.floor(self.probability * self.layers[1] * os.cpu_count())))
|
||||
self.current_sample = None
|
||||
self.enqueue_process = None
|
||||
|
||||
def start(self, buffering=BUFFER_SIZE):
|
||||
self.enqueue_process = Process(target=_enqueue_overlay_samples,
|
||||
args=(self.source, self.queue),
|
||||
kwargs={'buffering': buffering})
|
||||
self.enqueue_process.start()
|
||||
|
||||
def apply(self, sample, clock):
|
||||
sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP)
|
||||
n_layers = pick_value_from_range(self.layers, clock=clock)
|
||||
audio = sample.audio
|
||||
overlay_data = np.zeros_like(audio)
|
||||
for _ in range(n_layers):
|
||||
overlay_offset = 0
|
||||
while overlay_offset < len(audio):
|
||||
if self.current_sample is None:
|
||||
next_overlay_sample = self.queue.get()
|
||||
next_overlay_sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP)
|
||||
self.current_sample = next_overlay_sample.audio
|
||||
n_required = len(audio) - overlay_offset
|
||||
n_current = len(self.current_sample)
|
||||
if n_required >= n_current: # take it completely
|
||||
overlay_data[overlay_offset:overlay_offset + n_current] += self.current_sample
|
||||
overlay_offset += n_current
|
||||
self.current_sample = None
|
||||
else: # take required slice from head and keep tail for next layer or sample
|
||||
overlay_data[overlay_offset:overlay_offset + n_required] += self.current_sample[0:n_required]
|
||||
overlay_offset += n_required
|
||||
self.current_sample = self.current_sample[n_required:]
|
||||
snr_db = pick_value_from_range(self.snr, clock=clock)
|
||||
orig_dbfs = max_dbfs(audio)
|
||||
overlay_gain = orig_dbfs - max_dbfs(overlay_data) - snr_db
|
||||
audio += overlay_data * gain_db_to_ratio(overlay_gain)
|
||||
sample.audio = normalize_audio(audio, dbfs=orig_dbfs)
|
||||
|
||||
def stop(self):
|
||||
if self.enqueue_process is not None:
|
||||
self.enqueue_process.terminate()
|
||||
|
||||
|
||||
class Reverb(Augmentation):
|
||||
"""See "Reverb augmentation" in TRAINING.rst"""
|
||||
def __init__(self, p=1.0, delay=20.0, decay=10.0):
|
||||
super(Reverb, self).__init__(p)
|
||||
self.delay = float_range(delay)
|
||||
self.decay = float_range(decay)
|
||||
|
||||
def apply(self, sample, clock):
|
||||
sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP)
|
||||
audio = np.array(sample.audio, dtype=np.float64)
|
||||
orig_dbfs = max_dbfs(audio)
|
||||
delay = pick_value_from_range(self.delay, clock=clock)
|
||||
decay = pick_value_from_range(self.decay, clock=clock)
|
||||
decay = gain_db_to_ratio(-decay)
|
||||
result = np.copy(audio)
|
||||
primes = [17, 19, 23, 29, 31]
|
||||
for delay_prime in primes: # primes to minimize comb filter interference
|
||||
layer = np.copy(audio)
|
||||
n_delay = math.floor(delay * (delay_prime / primes[0]) * sample.audio_format.rate / 1000.0)
|
||||
n_delay = max(16, n_delay) # 16 samples minimum to avoid performance trap and risk of division by zero
|
||||
for w_index in range(0, math.floor(len(audio) / n_delay)):
|
||||
w1 = w_index * n_delay
|
||||
w2 = (w_index + 1) * n_delay
|
||||
width = min(len(audio) - w2, n_delay) # last window could be smaller
|
||||
layer[w2:w2 + width] += decay * layer[w1:w1 + width]
|
||||
result += layer
|
||||
audio = normalize_audio(result, dbfs=orig_dbfs)
|
||||
sample.audio = np.array(audio, dtype=np.float32)
|
||||
|
||||
|
||||
class Resample(Augmentation):
|
||||
"""See "Resample augmentation" in TRAINING.rst"""
|
||||
def __init__(self, p=1.0, rate=8000):
|
||||
super(Resample, self).__init__(p)
|
||||
self.rate = int_range(rate)
|
||||
|
||||
def apply(self, sample, clock):
|
||||
# late binding librosa and its dependencies
|
||||
from librosa.core import resample # pylint: disable=import-outside-toplevel
|
||||
sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP)
|
||||
rate = pick_value_from_range(self.rate, clock=clock)
|
||||
audio = sample.audio
|
||||
orig_len = len(audio)
|
||||
audio = np.swapaxes(audio, 0, 1)
|
||||
audio = resample(audio, sample.audio_format.rate, rate)
|
||||
audio = resample(audio, rate, sample.audio_format.rate)
|
||||
audio = np.swapaxes(audio, 0, 1)[0:orig_len]
|
||||
sample.audio = audio
|
||||
|
||||
|
||||
class Codec(Augmentation):
|
||||
"""See "Codec augmentation" in TRAINING.rst"""
|
||||
def __init__(self, p=1.0, bitrate=3200):
|
||||
super(Codec, self).__init__(p)
|
||||
self.bitrate = int_range(bitrate)
|
||||
|
||||
def apply(self, sample, clock):
|
||||
bitrate = pick_value_from_range(self.bitrate, clock=clock)
|
||||
sample.change_audio_type(new_audio_type=AUDIO_TYPE_PCM) # decoding to ensure it has to get encoded again
|
||||
sample.change_audio_type(new_audio_type=AUDIO_TYPE_OPUS, bitrate=bitrate) # will get decoded again downstream
|
||||
|
||||
|
||||
class Gaps(Augmentation):
|
||||
"""See "Gaps augmentation" in TRAINING.rst"""
|
||||
def __init__(self, p=1.0, n=1, size=50.0):
|
||||
super(Gaps, self).__init__(p)
|
||||
self.n_gaps = int_range(n)
|
||||
self.size = float_range(size)
|
||||
|
||||
def apply(self, sample, clock):
|
||||
sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP)
|
||||
audio = sample.audio
|
||||
n_gaps = pick_value_from_range(self.n_gaps, clock=clock)
|
||||
for _ in range(n_gaps):
|
||||
size = pick_value_from_range(self.size, clock=clock)
|
||||
size = int(size * sample.audio_format.rate / 1000.0)
|
||||
size = min(size, len(audio) // 10) # a gap should never exceed 10 percent of the audio
|
||||
offset = random.randint(0, max(0, len(audio) - size - 1))
|
||||
audio[offset:offset + size] = 0
|
||||
sample.audio = audio
|
||||
|
||||
|
||||
class Volume(Augmentation):
|
||||
"""See "Volume augmentation" in TRAINING.rst"""
|
||||
def __init__(self, p=1.0, dbfs=3.0103):
|
||||
super(Volume, self).__init__(p)
|
||||
self.target_dbfs = float_range(dbfs)
|
||||
|
||||
def apply(self, sample, clock):
|
||||
sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP)
|
||||
target_dbfs = pick_value_from_range(self.target_dbfs, clock=clock)
|
||||
sample.audio = normalize_audio(sample.audio, dbfs=target_dbfs)
|
||||
|
||||
|
||||
def parse_augmentation(augmentation_spec):
|
||||
"""
|
||||
Parses an augmentation specification.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
augmentation_spec : str
|
||||
Augmentation specification like "reverb[delay=20.0,decay=-20]".
|
||||
|
||||
Returns
|
||||
-------
|
||||
Instance of an augmentation class from util.signal_augmentations.*.
|
||||
"""
|
||||
match = SPEC_PARSER.match(augmentation_spec)
|
||||
if not match:
|
||||
raise ValueError('Augmentation specification has wrong format')
|
||||
cls_name = match.group('cls')
|
||||
cls_name = cls_name[0].upper() + cls_name[1:]
|
||||
augmentation_cls = globals()[cls_name] if cls_name in globals() else None
|
||||
if not issubclass(augmentation_cls, Augmentation) or augmentation_cls == Augmentation:
|
||||
raise ValueError('Unknown augmentation: {}'.format(cls_name))
|
||||
parameters = match.group('params')
|
||||
parameters = [] if parameters is None else parameters.split(',')
|
||||
args = []
|
||||
kwargs = {}
|
||||
for parameter in parameters:
|
||||
pair = tuple(list(map(str.strip, (parameter.split('=')))))
|
||||
if len(pair) == 1:
|
||||
args.append(pair)
|
||||
elif len(pair) == 2:
|
||||
kwargs[pair[0]] = pair[1]
|
||||
else:
|
||||
raise ValueError('Unable to parse augmentation value assignment')
|
||||
return augmentation_cls(*args, **kwargs)
|
|
@ -1,127 +0,0 @@
|
|||
import tensorflow as tf
|
||||
import tensorflow.compat.v1 as tfv1
|
||||
|
||||
from .sparse_image_warp import sparse_image_warp
|
||||
|
||||
def augment_freq_time_mask(spectrogram,
|
||||
frequency_masking_para=30,
|
||||
time_masking_para=10,
|
||||
frequency_mask_num=3,
|
||||
time_mask_num=3):
|
||||
time_max = tf.shape(spectrogram)[1]
|
||||
freq_max = tf.shape(spectrogram)[2]
|
||||
# Frequency masking
|
||||
for _ in range(frequency_mask_num):
|
||||
f = tf.random.uniform(shape=(), minval=0, maxval=frequency_masking_para, dtype=tf.dtypes.int32)
|
||||
f0 = tf.random.uniform(shape=(), minval=0, maxval=freq_max - f, dtype=tf.dtypes.int32)
|
||||
value_ones_freq_prev = tf.ones(shape=[1, time_max, f0])
|
||||
value_zeros_freq = tf.zeros(shape=[1, time_max, f])
|
||||
value_ones_freq_next = tf.ones(shape=[1, time_max, freq_max-(f0+f)])
|
||||
freq_mask = tf.concat([value_ones_freq_prev, value_zeros_freq, value_ones_freq_next], axis=2)
|
||||
# mel_spectrogram[:, f0:f0 + f, :] = 0 #can't assign to tensor
|
||||
# mel_spectrogram[:, f0:f0 + f, :] = value_zeros_freq #can't assign to tensor
|
||||
spectrogram = spectrogram*freq_mask
|
||||
|
||||
# Time masking
|
||||
for _ in range(time_mask_num):
|
||||
t = tf.random.uniform(shape=(), minval=0, maxval=time_masking_para, dtype=tf.dtypes.int32)
|
||||
t0 = tf.random.uniform(shape=(), minval=0, maxval=time_max - t, dtype=tf.dtypes.int32)
|
||||
value_zeros_time_prev = tf.ones(shape=[1, t0, freq_max])
|
||||
value_zeros_time = tf.zeros(shape=[1, t, freq_max])
|
||||
value_zeros_time_next = tf.ones(shape=[1, time_max-(t0+t), freq_max])
|
||||
time_mask = tf.concat([value_zeros_time_prev, value_zeros_time, value_zeros_time_next], axis=1)
|
||||
# mel_spectrogram[:, :, t0:t0 + t] = 0 #can't assign to tensor
|
||||
# mel_spectrogram[:, :, t0:t0 + t] = value_zeros_time #can't assign to tensor
|
||||
spectrogram = spectrogram*time_mask
|
||||
|
||||
return spectrogram
|
||||
|
||||
def augment_pitch_and_tempo(spectrogram,
|
||||
max_tempo=1.2,
|
||||
max_pitch=1.1,
|
||||
min_pitch=0.95):
|
||||
original_shape = tf.shape(spectrogram)
|
||||
choosen_pitch = tf.random.uniform(shape=(), minval=min_pitch, maxval=max_pitch)
|
||||
choosen_tempo = tf.random.uniform(shape=(), minval=1, maxval=max_tempo)
|
||||
new_freq_size = tf.cast(tf.cast(original_shape[2], tf.float32)*choosen_pitch, tf.int32)
|
||||
new_time_size = tf.cast(tf.cast(original_shape[1], tf.float32)/(choosen_tempo), tf.int32)
|
||||
spectrogram_aug = tf.image.resize_bilinear(tf.expand_dims(spectrogram, -1), [new_time_size, new_freq_size])
|
||||
spectrogram_aug = tf.image.crop_to_bounding_box(spectrogram_aug, offset_height=0, offset_width=0, target_height=tf.shape(spectrogram_aug)[1], target_width=tf.minimum(original_shape[2], new_freq_size))
|
||||
spectrogram_aug = tf.cond(choosen_pitch < 1,
|
||||
lambda: tf.image.pad_to_bounding_box(spectrogram_aug, offset_height=0, offset_width=0,
|
||||
target_height=tf.shape(spectrogram_aug)[1], target_width=original_shape[2]),
|
||||
lambda: spectrogram_aug)
|
||||
return spectrogram_aug[:, :, :, 0]
|
||||
|
||||
|
||||
def augment_speed_up(spectrogram,
|
||||
speed_std=0.1):
|
||||
original_shape = tf.shape(spectrogram)
|
||||
choosen_speed = tf.math.abs(tf.random.normal(shape=(), stddev=speed_std)) # abs makes sure the augmention will only speed up
|
||||
choosen_speed = 1 + choosen_speed
|
||||
new_freq_size = tf.cast(tf.cast(original_shape[2], tf.float32), tf.int32)
|
||||
new_time_size = tf.cast(tf.cast(original_shape[1], tf.float32)/(choosen_speed), tf.int32)
|
||||
spectrogram_aug = tf.image.resize_bilinear(tf.expand_dims(spectrogram, -1), [new_time_size, new_freq_size])
|
||||
return spectrogram_aug[:, :, :, 0]
|
||||
|
||||
def augment_dropout(spectrogram,
|
||||
keep_prob=0.95):
|
||||
return tf.nn.dropout(spectrogram, rate=1-keep_prob)
|
||||
|
||||
|
||||
def augment_sparse_warp(spectrogram, time_warping_para=20, interpolation_order=2, regularization_weight=0.0, num_boundary_points=1, num_control_points=1):
|
||||
"""Reference: https://arxiv.org/pdf/1904.08779.pdf
|
||||
Args:
|
||||
spectrogram: `[batch, time, frequency]` float `Tensor`
|
||||
time_warping_para: 'W' parameter in paper
|
||||
interpolation_order: used to put into `sparse_image_warp`
|
||||
regularization_weight: used to put into `sparse_image_warp`
|
||||
num_boundary_points: used to put into `sparse_image_warp`,
|
||||
default=1 means boundary points on 4 corners of the image
|
||||
num_control_points: number of control points
|
||||
Returns:
|
||||
warped_spectrogram: `[batch, time, frequency]` float `Tensor` with same
|
||||
type as input image.
|
||||
"""
|
||||
# reshape to fit `sparse_image_warp`'s input shape
|
||||
# (1, time steps, freq, 1), batch_size must be 1
|
||||
spectrogram = tf.expand_dims(spectrogram, -1)
|
||||
|
||||
original_shape = tf.shape(spectrogram)
|
||||
tau, freq_size = original_shape[1], original_shape[2]
|
||||
|
||||
# to protect short audio
|
||||
time_warping_para = tf.math.minimum(
|
||||
time_warping_para, tf.math.subtract(tf.math.floordiv(tau, 2), 1))
|
||||
|
||||
# don't choose boundary frequency
|
||||
choosen_freqs = tf.random.shuffle(
|
||||
tf.add(tf.range(freq_size - 3), 1))[0: num_control_points]
|
||||
|
||||
source_max = tau - time_warping_para
|
||||
source_min = tf.math.minimum(source_max - num_control_points, time_warping_para)
|
||||
|
||||
choosen_times = tf.random.shuffle(tf.range(source_min, limit=source_max))[0: num_control_points]
|
||||
dest_time_widths = tfv1.random_uniform([num_control_points], tf.negative(time_warping_para), time_warping_para, tf.int32)
|
||||
|
||||
sources = []
|
||||
dests = []
|
||||
for i in range(num_control_points):
|
||||
# generate source points `t` of time axis between (W, tau-W)
|
||||
rand_source_time = choosen_times[i]
|
||||
rand_dest_time = rand_source_time + dest_time_widths[i]
|
||||
|
||||
choosen_freq = choosen_freqs[i]
|
||||
sources.append([rand_source_time, choosen_freq])
|
||||
dests.append([rand_dest_time, choosen_freq])
|
||||
|
||||
source_control_point_locations = tf.cast([sources], tf.float32)
|
||||
dest_control_point_locations = tf.cast([dests], tf.float32)
|
||||
|
||||
warped_spectrogram, _ = sparse_image_warp(spectrogram,
|
||||
source_control_point_locations=source_control_point_locations,
|
||||
dest_control_point_locations=dest_control_point_locations,
|
||||
interpolation_order=interpolation_order,
|
||||
regularization_weight=regularization_weight,
|
||||
num_boundary_points=num_boundary_points)
|
||||
return tf.reshape(warped_spectrogram, shape=(1, -1, freq_size))
|
Загрузка…
Ссылка в новой задаче