зеркало из https://github.com/mozilla/DeepSpeech.git
Remote training I/O once more (#3437)
* Redo remote I/O changes once more; this time without messing with taskcluster
* Add bin changes
* Fix merge-induced issue?
* For the interleaved case with multiple collections, unpack audio on the fly
To reproduce the previous failure
rm data/smoke_test/ldc93s1.csv
rm data/smoke_test/ldc93s1.sdb
rm -rf /tmp/ldc93s1_cache_sdb_csv
rm -rf /tmp/ckpt_sdb_csv
rm -rf /tmp/train_sdb_csv
./bin/run-tc-ldc93s1_new_sdb_csv.sh 109 16000
python -u DeepSpeech.py --noshow_progressbar --noearly_stop --train_files ./data/smoke_test/ldc93s1.sdb,./data/smoke_test/ldc93s1.csv --train_batch_size 1 --feature_cache /tmp/ldc93s1_cache_sdb_csv --dev_files ./data/smoke_test/ldc93s1.sdb,./data/smoke_test/ldc93s1.csv --dev_batch_size 1 --test_files ./data/smoke_test/ldc93s1.sdb,./data/smoke_test/ldc93s1.csv --test_batch_size 1 --n_hidden 100 --epochs 109 --max_to_keep 1 --checkpoint_dir /tmp/ckpt_sdb_csv --learning_rate 0.001 --dropout_rate 0.05 --export_dir /tmp/train_sdb_csv --scorer_path data/smoke_test/pruned_lm.scorer --audio_sample_rate 16000
* Attempt to preserve length information with a wrapper around `map()`… this gets pretty python-y
* Call the right `__next__()`
* Properly implement the rest of the map wrappers here……
* Fix trailing whitespace situation and other linter complaints
* Remove data accidentally checked in
* Fix overlay augmentations
* Wavs must be open in rb mode if we're passing in an external file pointer -- this confused me
* Lint whitespace
* Revert "Fix trailing whitespace situation and other linter complaints"
This reverts commit c3c45397a2
.
* Fix linter issue but without such an aggressive diff
* Move unpack_maybe into sample_collections
* Use unpack_maybe in place of duplicate lambda
* Fix confusing comment
* Add clarifying comment for on-the-fly unpacking
This commit is contained in:
Родитель
18b66adf46
Коммит
6640cf2341
|
@ -15,8 +15,8 @@ def fail(message):
|
|||
|
||||
|
||||
def compare_samples():
|
||||
sample1 = load_sample(CLI_ARGS.sample1)
|
||||
sample2 = load_sample(CLI_ARGS.sample2)
|
||||
sample1 = load_sample(CLI_ARGS.sample1).unpack()
|
||||
sample2 = load_sample(CLI_ARGS.sample2).unpack()
|
||||
if sample1.audio_format != sample2.audio_format:
|
||||
fail('Samples differ on: audio-format ({} and {})'.format(sample1.audio_format, sample2.audio_format))
|
||||
if sample1.duration != sample2.duration:
|
||||
|
|
|
@ -35,6 +35,7 @@ from .util.feeding import create_dataset, audio_to_features, audiofile_to_featur
|
|||
from .util.flags import create_flags, FLAGS
|
||||
from .util.helpers import check_ctcdecoder_version, ExceptionBox
|
||||
from .util.logging import create_progressbar, log_debug, log_error, log_info, log_progress, log_warn
|
||||
from .util.io import open_remote, remove_remote, listdir_remote, is_remote_path, isdir_remote
|
||||
|
||||
check_ctcdecoder_version()
|
||||
|
||||
|
@ -512,9 +513,10 @@ def train():
|
|||
best_dev_path = os.path.join(FLAGS.save_checkpoint_dir, 'best_dev')
|
||||
|
||||
# Save flags next to checkpoints
|
||||
os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True)
|
||||
if not is_remote_path(FLAGS.save_checkpoint_dir):
|
||||
os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True)
|
||||
flags_file = os.path.join(FLAGS.save_checkpoint_dir, 'flags.txt')
|
||||
with open(flags_file, 'w') as fout:
|
||||
with open_remote(flags_file, 'w') as fout:
|
||||
fout.write(FLAGS.flags_into_string())
|
||||
|
||||
with tfv1.Session(config=Config.session_config) as session:
|
||||
|
@ -541,7 +543,7 @@ def train():
|
|||
feature_cache_index = FLAGS.feature_cache + '.index'
|
||||
if epoch % FLAGS.cache_for_epochs == 0 and os.path.isfile(feature_cache_index):
|
||||
log_info('Invalidating feature cache')
|
||||
os.remove(feature_cache_index) # this will let TF also overwrite the related cache data files
|
||||
remove_remote(feature_cache_index) # this will let TF also overwrite the related cache data files
|
||||
|
||||
# Setup progress bar
|
||||
class LossWidget(progressbar.widgets.FormatLabel):
|
||||
|
@ -810,13 +812,13 @@ def export():
|
|||
|
||||
output_filename = FLAGS.export_file_name + '.pb'
|
||||
if FLAGS.remove_export:
|
||||
if os.path.isdir(FLAGS.export_dir):
|
||||
if isdir_remote(FLAGS.export_dir):
|
||||
log_info('Removing old export')
|
||||
shutil.rmtree(FLAGS.export_dir)
|
||||
remove_remote(FLAGS.export_dir)
|
||||
|
||||
output_graph_path = os.path.join(FLAGS.export_dir, output_filename)
|
||||
|
||||
if not os.path.isdir(FLAGS.export_dir):
|
||||
if not is_remote_path(FLAGS.export_dir) and not os.path.isdir(FLAGS.export_dir):
|
||||
os.makedirs(FLAGS.export_dir)
|
||||
|
||||
frozen_graph = tfv1.graph_util.convert_variables_to_constants(
|
||||
|
@ -829,7 +831,7 @@ def export():
|
|||
dest_nodes=output_names)
|
||||
|
||||
if not FLAGS.export_tflite:
|
||||
with open(output_graph_path, 'wb') as fout:
|
||||
with open_remote(output_graph_path, 'wb') as fout:
|
||||
fout.write(frozen_graph.SerializeToString())
|
||||
else:
|
||||
output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))
|
||||
|
@ -840,7 +842,7 @@ def export():
|
|||
converter.allow_custom_ops = True
|
||||
tflite_model = converter.convert()
|
||||
|
||||
with open(output_tflite_path, 'wb') as fout:
|
||||
with open_remote(output_tflite_path, 'wb') as fout:
|
||||
fout.write(tflite_model)
|
||||
|
||||
log_info('Models exported at %s' % (FLAGS.export_dir))
|
||||
|
@ -851,7 +853,7 @@ def export():
|
|||
FLAGS.export_model_version))
|
||||
|
||||
model_runtime = 'tflite' if FLAGS.export_tflite else 'tensorflow'
|
||||
with open(metadata_fname, 'w') as f:
|
||||
with open_remote(metadata_fname, 'w') as f:
|
||||
f.write('---\n')
|
||||
f.write('author: {}\n'.format(FLAGS.export_author_id))
|
||||
f.write('model_name: {}\n'.format(FLAGS.export_model_name))
|
||||
|
@ -873,8 +875,12 @@ def export():
|
|||
def package_zip():
|
||||
# --export_dir path/to/export/LANG_CODE/ => path/to/export/LANG_CODE.zip
|
||||
export_dir = os.path.join(os.path.abspath(FLAGS.export_dir), '') # Force ending '/'
|
||||
zip_filename = os.path.dirname(export_dir)
|
||||
if is_remote_path(export_dir):
|
||||
log_error("Cannot package remote path zip %s. Please do this manually." % export_dir)
|
||||
return
|
||||
|
||||
zip_filename = os.path.dirname(export_dir)
|
||||
|
||||
shutil.copy(FLAGS.scorer_path, export_dir)
|
||||
|
||||
archive = shutil.make_archive(zip_filename, 'zip', export_dir)
|
||||
|
@ -959,7 +965,7 @@ def main(_):
|
|||
tfv1.reset_default_graph()
|
||||
FLAGS.export_tflite = True
|
||||
|
||||
if os.listdir(FLAGS.export_dir):
|
||||
if listdir_remote(FLAGS.export_dir):
|
||||
log_error('Directory {} is not empty, please fix this.'.format(FLAGS.export_dir))
|
||||
sys.exit(1)
|
||||
|
||||
|
|
|
@ -8,6 +8,7 @@ import numpy as np
|
|||
|
||||
from .helpers import LimitingPool
|
||||
from collections import namedtuple
|
||||
from .io import open_remote, remove_remote, copy_remote, is_remote_path
|
||||
|
||||
AudioFormat = namedtuple('AudioFormat', 'rate channels width')
|
||||
|
||||
|
@ -117,15 +118,19 @@ class Sample:
|
|||
self.audio_type = new_audio_type
|
||||
|
||||
|
||||
def _change_audio_type(sample_and_audio_type):
|
||||
sample, audio_type, bitrate = sample_and_audio_type
|
||||
def _unpack_and_change_audio_type(sample_and_audio_type):
|
||||
packed_sample, audio_type, bitrate = sample_and_audio_type
|
||||
if hasattr(packed_sample, 'unpack'):
|
||||
sample = packed_sample.unpack()
|
||||
else:
|
||||
sample = packed_sample
|
||||
sample.change_audio_type(audio_type, bitrate=bitrate)
|
||||
return sample
|
||||
|
||||
|
||||
def change_audio_types(samples, audio_type=AUDIO_TYPE_PCM, bitrate=None, processes=None, process_ahead=None):
|
||||
def change_audio_types(packed_samples, audio_type=AUDIO_TYPE_PCM, bitrate=None, processes=None, process_ahead=None):
|
||||
with LimitingPool(processes=processes, process_ahead=process_ahead) as pool:
|
||||
yield from pool.imap(_change_audio_type, map(lambda s: (s, audio_type, bitrate), samples))
|
||||
yield from pool.imap(_unpack_and_change_audio_type, map(lambda s: (s, audio_type, bitrate), packed_samples))
|
||||
|
||||
|
||||
def get_audio_type_from_extension(ext):
|
||||
|
@ -168,29 +173,45 @@ class AudioFile:
|
|||
self.audio_format = audio_format
|
||||
self.as_path = as_path
|
||||
self.open_file = None
|
||||
self.open_wav = None
|
||||
self.tmp_file_path = None
|
||||
self.tmp_src_file_path = None
|
||||
|
||||
def __enter__(self):
|
||||
if self.audio_path.endswith('.wav'):
|
||||
self.open_file = wave.open(self.audio_path, 'r')
|
||||
if read_audio_format_from_wav_file(self.open_file) == self.audio_format:
|
||||
self.open_file = open_remote(self.audio_path, 'rb')
|
||||
self.open_wav = wave.open(self.open_file)
|
||||
if read_audio_format_from_wav_file(self.open_wav) == self.audio_format:
|
||||
if self.as_path:
|
||||
self.open_wav.close()
|
||||
self.open_file.close()
|
||||
return self.audio_path
|
||||
return self.open_file
|
||||
return self.open_wav
|
||||
self.open_wav.close()
|
||||
self.open_file.close()
|
||||
|
||||
# If the format isn't right, copy the file to local tmp dir and do the conversion on disk
|
||||
if is_remote_path(self.audio_path):
|
||||
_, self.tmp_src_file_path = tempfile.mkstemp(suffix='.wav')
|
||||
copy_remote(self.audio_path, self.tmp_src_file_path)
|
||||
self.audio_path = self.tmp_file_path
|
||||
|
||||
_, self.tmp_file_path = tempfile.mkstemp(suffix='.wav')
|
||||
convert_audio(self.audio_path, self.tmp_file_path, file_type='wav', audio_format=self.audio_format)
|
||||
if self.as_path:
|
||||
return self.tmp_file_path
|
||||
self.open_file = wave.open(self.tmp_file_path, 'r')
|
||||
return self.open_file
|
||||
self.open_wav = wave.open(self.tmp_file_path, 'rb')
|
||||
return self.open_wav
|
||||
|
||||
def __exit__(self, *args):
|
||||
if not self.as_path:
|
||||
self.open_file.close()
|
||||
self.open_wav.close()
|
||||
if self.open_file:
|
||||
self.open_file.close()
|
||||
if self.tmp_file_path is not None:
|
||||
os.remove(self.tmp_file_path)
|
||||
if self.tmp_src_file_path is not None:
|
||||
os.remove(self.tmp_src_file_path)
|
||||
|
||||
|
||||
def read_frames(wav_file, frame_duration_ms=30, yield_remainder=False):
|
||||
|
@ -320,6 +341,7 @@ def read_opus(opus_file):
|
|||
|
||||
|
||||
def write_wav(wav_file, pcm_data, audio_format=DEFAULT_FORMAT):
|
||||
# wav_file is already a file-pointer here
|
||||
with wave.open(wav_file, 'wb') as wav_file_writer:
|
||||
wav_file_writer.setframerate(audio_format.rate)
|
||||
wav_file_writer.setnchannels(audio_format.channels)
|
||||
|
|
|
@ -8,7 +8,7 @@ import numpy as np
|
|||
from multiprocessing import Queue, Process
|
||||
from .audio import gain_db_to_ratio, max_dbfs, normalize_audio, AUDIO_TYPE_NP, AUDIO_TYPE_PCM, AUDIO_TYPE_OPUS
|
||||
from .helpers import LimitingPool, int_range, float_range, pick_value_from_range, tf_pick_value_from_range, MEGABYTE
|
||||
from .sample_collections import samples_from_source
|
||||
from .sample_collections import samples_from_source, unpack_maybe
|
||||
|
||||
BUFFER_SIZE = 1 * MEGABYTE
|
||||
SPEC_PARSER = re.compile(r'^(?P<cls>[a-z_]+)(\[(?P<params>.*)\])?$')
|
||||
|
@ -150,6 +150,12 @@ def _init_augmentation_worker(preparation_context):
|
|||
AUGMENTATION_CONTEXT = preparation_context
|
||||
|
||||
|
||||
def _load_and_augment_sample(timed_sample, context=None):
|
||||
sample, clock = timed_sample
|
||||
realized_sample = unpack_maybe(sample)
|
||||
return _augment_sample((realized_sample, clock), context)
|
||||
|
||||
|
||||
def _augment_sample(timed_sample, context=None):
|
||||
context = AUGMENTATION_CONTEXT if context is None else context
|
||||
sample, clock = timed_sample
|
||||
|
@ -213,12 +219,12 @@ def apply_sample_augmentations(samples,
|
|||
context = AugmentationContext(audio_type, augmentations)
|
||||
if process_ahead == 0:
|
||||
for timed_sample in timed_samples():
|
||||
yield _augment_sample(timed_sample, context=context)
|
||||
yield _load_and_augment_sample(timed_sample, context=context)
|
||||
else:
|
||||
with LimitingPool(process_ahead=process_ahead,
|
||||
initializer=_init_augmentation_worker,
|
||||
initargs=(context,)) as pool:
|
||||
yield from pool.imap(_augment_sample, timed_samples())
|
||||
yield from pool.imap(_load_and_augment_sample, timed_samples())
|
||||
finally:
|
||||
for augmentation in augmentations:
|
||||
augmentation.stop()
|
||||
|
@ -256,6 +262,7 @@ class Overlay(SampleAugmentation):
|
|||
self.enqueue_process.start()
|
||||
|
||||
def apply(self, sample, clock=0.0):
|
||||
sample = unpack_maybe(sample)
|
||||
sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP)
|
||||
n_layers = pick_value_from_range(self.layers, clock=clock)
|
||||
audio = sample.audio
|
||||
|
@ -265,6 +272,7 @@ class Overlay(SampleAugmentation):
|
|||
while overlay_offset < len(audio):
|
||||
if self.current_sample is None:
|
||||
next_overlay_sample = self.queue.get()
|
||||
next_overlay_sample = unpack_maybe(next_overlay_sample)
|
||||
next_overlay_sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP)
|
||||
self.current_sample = next_overlay_sample.audio
|
||||
n_required = len(audio) - overlay_offset
|
||||
|
|
|
@ -19,6 +19,7 @@ import csv
|
|||
import os
|
||||
import sys
|
||||
import unicodedata
|
||||
from .io import open_remote
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
@ -27,14 +28,14 @@ def main():
|
|||
parser.add_argument("-alpha", "--alphabet-format", help="Bool. Print in format for alphabet.txt", action="store_true")
|
||||
parser.add_argument("-unicode", "--disable-unicode-variants", help="Bool. DISABLE check for unicode consistency (use with --alphabet-format)", action="store_true")
|
||||
args = parser.parse_args()
|
||||
in_files = [os.path.abspath(i) for i in args.csv_files.split(",")]
|
||||
in_files = args.csv_files.split(",")
|
||||
|
||||
print("### Reading in the following transcript files: ###")
|
||||
print("### {} ###".format(in_files))
|
||||
|
||||
all_text = set()
|
||||
for in_file in in_files:
|
||||
with open(in_file, "r") as csv_file:
|
||||
with open_remote(in_file, "r") as csv_file:
|
||||
reader = csv.reader(csv_file)
|
||||
try:
|
||||
next(reader, None) # skip the file header (i.e. "transcript")
|
||||
|
|
|
@ -13,7 +13,7 @@ from .gpu import get_available_gpus
|
|||
from .logging import log_error, log_warn
|
||||
from .helpers import parse_file_size
|
||||
from .augmentations import parse_augmentations
|
||||
|
||||
from .io import path_exists_remote
|
||||
|
||||
class ConfigSingleton:
|
||||
_config = None
|
||||
|
@ -139,7 +139,7 @@ def initialize_globals():
|
|||
c.audio_step_samples = FLAGS.audio_sample_rate * (FLAGS.feature_win_step / 1000)
|
||||
|
||||
if FLAGS.one_shot_infer:
|
||||
if not os.path.exists(FLAGS.one_shot_infer):
|
||||
if not path_exists_remote(FLAGS.one_shot_infer):
|
||||
log_error('Path specified in --one_shot_infer is not a valid file.')
|
||||
sys.exit(1)
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@ import requests
|
|||
import progressbar
|
||||
|
||||
from os import path, makedirs
|
||||
from .io import open_remote, path_exists_remote, is_remote_path
|
||||
|
||||
SIMPLE_BAR = ['Progress ', progressbar.Bar(), ' ', progressbar.Percentage(), ' completed']
|
||||
|
||||
|
@ -9,17 +10,18 @@ def maybe_download(archive_name, target_dir, archive_url):
|
|||
# If archive file does not exist, download it...
|
||||
archive_path = path.join(target_dir, archive_name)
|
||||
|
||||
if not path.exists(target_dir):
|
||||
if not is_remote_path(target_dir) and not path.exists(target_dir):
|
||||
print('No path "%s" - creating ...' % target_dir)
|
||||
makedirs(target_dir)
|
||||
|
||||
if not path.exists(archive_path):
|
||||
if not path_exists_remote(archive_path):
|
||||
print('No archive "%s" - downloading...' % archive_path)
|
||||
req = requests.get(archive_url, stream=True)
|
||||
total_size = int(req.headers.get('content-length', 0))
|
||||
done = 0
|
||||
with open(archive_path, 'wb') as f:
|
||||
with open_remote(archive_path, 'wb') as f:
|
||||
bar = progressbar.ProgressBar(max_value=total_size if total_size > 0 else progressbar.UnknownLength, widgets=SIMPLE_BAR)
|
||||
|
||||
for data in req.iter_content(1024*1024):
|
||||
done += len(data)
|
||||
f.write(data)
|
||||
|
|
|
@ -10,7 +10,7 @@ from attrdict import AttrDict
|
|||
|
||||
from .flags import FLAGS
|
||||
from .text import levenshtein
|
||||
|
||||
from .io import open_remote
|
||||
|
||||
def pmap(fun, iterable):
|
||||
pool = Pool()
|
||||
|
@ -124,5 +124,5 @@ def save_samples_json(samples, output_path):
|
|||
We set ensure_ascii=True to prevent json from escaping non-ASCII chars
|
||||
in the texts.
|
||||
'''
|
||||
with open(output_path, 'w') as fout:
|
||||
with open_remote(output_path, 'w') as fout:
|
||||
json.dump(samples, fout, default=float, ensure_ascii=False, indent=2)
|
||||
|
|
|
@ -78,6 +78,32 @@ class Interleaved:
|
|||
return self.len
|
||||
|
||||
|
||||
class LenMap:
|
||||
"""
|
||||
Wrapper around python map() output object that preserves the original collection length
|
||||
by implementing __len__.
|
||||
"""
|
||||
def __init__(self, fn, iterable):
|
||||
try:
|
||||
self.length = len(iterable)
|
||||
except TypeError:
|
||||
self.length = None
|
||||
self.mapobj = map(fn, iterable)
|
||||
|
||||
def __iter__(self):
|
||||
self.mapobj = self.mapobj.__iter__()
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
return self.mapobj.__next__()
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.mapobj.__getitem__(key)
|
||||
|
||||
def __len__(self):
|
||||
return self.length
|
||||
|
||||
|
||||
class LimitingPool:
|
||||
"""Limits unbound ahead-processing of multiprocessing.Pool's imap method
|
||||
before items get consumed by the iteration caller.
|
||||
|
|
|
@ -0,0 +1,81 @@
|
|||
"""
|
||||
A set of I/O utils that allow us to open files on remote storage as if they were present locally and access
|
||||
into HDFS storage using Tensorflow's C++ FileStream API.
|
||||
Currently only includes wrappers for Google's GCS, but this can easily be expanded for AWS S3 buckets.
|
||||
"""
|
||||
import os
|
||||
from tensorflow.io import gfile
|
||||
|
||||
|
||||
def is_remote_path(path):
|
||||
"""
|
||||
Returns True iff the path is one of the remote formats that this
|
||||
module supports
|
||||
"""
|
||||
return path.startswith('gs://') or path.startswith('hdfs://')
|
||||
|
||||
|
||||
def path_exists_remote(path):
|
||||
"""
|
||||
Wrapper that allows existance check of local and remote paths like
|
||||
`gs://...`
|
||||
"""
|
||||
if is_remote_path(path):
|
||||
return gfile.exists(path)
|
||||
return os.path.exists(path)
|
||||
|
||||
|
||||
def copy_remote(src, dst, overwrite=False):
|
||||
"""
|
||||
Allows us to copy a file from local to remote or vice versa
|
||||
"""
|
||||
return gfile.copy(src, dst, overwrite)
|
||||
|
||||
|
||||
def open_remote(path, mode='r', buffering=-1, encoding=None, newline=None, closefd=True, opener=None):
|
||||
"""
|
||||
Wrapper around open() method that can handle remote paths like `gs://...`
|
||||
off Google Cloud using Tensorflow's IO helpers.
|
||||
|
||||
buffering, encoding, newline, closefd, and opener are ignored for remote files
|
||||
|
||||
This enables us to do:
|
||||
with open_remote('gs://.....', mode='w+') as f:
|
||||
do something with the file f, whether or not we have local access to it
|
||||
"""
|
||||
if is_remote_path(path):
|
||||
return gfile.GFile(path, mode=mode)
|
||||
return open(path, mode, buffering=buffering, encoding=encoding, newline=newline, closefd=closefd, opener=opener)
|
||||
|
||||
|
||||
def isdir_remote(path):
|
||||
"""
|
||||
Wrapper to check if remote and local paths are directories
|
||||
"""
|
||||
if is_remote_path(path):
|
||||
return gfile.isdir(path)
|
||||
return os.path.isdir(path)
|
||||
|
||||
|
||||
def listdir_remote(path):
|
||||
"""
|
||||
Wrapper to list paths in local dirs (alternative to using a glob, I suppose)
|
||||
"""
|
||||
if is_remote_path(path):
|
||||
return gfile.listdir(path)
|
||||
return os.listdir(path)
|
||||
|
||||
|
||||
def glob_remote(filename):
|
||||
"""
|
||||
Wrapper that provides globs on local and remote paths like `gs://...`
|
||||
"""
|
||||
return gfile.glob(filename)
|
||||
|
||||
|
||||
def remove_remote(filename):
|
||||
"""
|
||||
Wrapper that can remove local and remote files like `gs://...`
|
||||
"""
|
||||
# Conditional import
|
||||
return gfile.remove_remote(filename)
|
|
@ -8,7 +8,7 @@ import tarfile
|
|||
from pathlib import Path
|
||||
from functools import partial
|
||||
|
||||
from .helpers import KILOBYTE, MEGABYTE, GIGABYTE, Interleaved
|
||||
from .helpers import KILOBYTE, MEGABYTE, GIGABYTE, Interleaved, LenMap
|
||||
from .audio import (
|
||||
Sample,
|
||||
DEFAULT_FORMAT,
|
||||
|
@ -18,6 +18,7 @@ from .audio import (
|
|||
get_audio_type_from_extension,
|
||||
write_wav
|
||||
)
|
||||
from .io import open_remote, is_remote_path
|
||||
|
||||
BIG_ENDIAN = 'big'
|
||||
INT_SIZE = 4
|
||||
|
@ -59,6 +60,37 @@ class LabeledSample(Sample):
|
|||
self.transcript = transcript
|
||||
|
||||
|
||||
class PackedSample:
|
||||
"""
|
||||
A wrapper that we can carry around in an iterator and pass to a child process in order to
|
||||
have the child process do the loading/unpacking of the sample, allowing for parallel file
|
||||
I/O.
|
||||
"""
|
||||
def __init__(self, filename, audio_type, label):
|
||||
self.filename = filename
|
||||
self.audio_type = audio_type
|
||||
self.label = label
|
||||
|
||||
def unpack(self):
|
||||
with open_remote(self.filename, 'rb') as audio_file:
|
||||
data = audio_file.read()
|
||||
if self.label is None:
|
||||
s = Sample(self.audio_type, data, sample_id=self.filename)
|
||||
s = LabeledSample(self.audio_type, data, self.label, sample_id=self.filename)
|
||||
return s
|
||||
|
||||
|
||||
def unpack_maybe(sample):
|
||||
"""
|
||||
Loads the supplied sample from disk (or the network) if the audio isn't loaded in to memory already.
|
||||
"""
|
||||
if hasattr(sample, 'unpack'):
|
||||
realized_sample = sample.unpack()
|
||||
else:
|
||||
realized_sample = sample
|
||||
return realized_sample
|
||||
|
||||
|
||||
def load_sample(filename, label=None):
|
||||
"""
|
||||
Loads audio-file as a (labeled or unlabeled) sample
|
||||
|
@ -69,21 +101,19 @@ def load_sample(filename, label=None):
|
|||
Filename of the audio-file to load as sample
|
||||
label : str
|
||||
Label (transcript) of the sample.
|
||||
If None: return util.audio.Sample instance
|
||||
Otherwise: return util.sample_collections.LabeledSample instance
|
||||
If None: returned result.unpack() will return util.audio.Sample instance
|
||||
Otherwise: returned result.unpack() util.sample_collections.LabeledSample instance
|
||||
|
||||
Returns
|
||||
-------
|
||||
util.audio.Sample instance if label is None, else util.sample_collections.LabeledSample instance
|
||||
util.sample_collections.PackedSample, a wrapper object, on which calling unpack() will return
|
||||
util.audio.Sample instance if label is None, else util.sample_collections.LabeledSample instance
|
||||
"""
|
||||
ext = os.path.splitext(filename)[1].lower()
|
||||
audio_type = get_audio_type_from_extension(ext)
|
||||
if audio_type is None:
|
||||
raise ValueError('Unknown audio type extension "{}"'.format(ext))
|
||||
with open(filename, 'rb') as audio_file:
|
||||
if label is None:
|
||||
return Sample(audio_type, audio_file.read(), sample_id=filename)
|
||||
return LabeledSample(audio_type, audio_file.read(), label, sample_id=filename)
|
||||
return PackedSample(filename, audio_type, label)
|
||||
|
||||
|
||||
class DirectSDBWriter:
|
||||
|
@ -119,7 +149,7 @@ class DirectSDBWriter:
|
|||
raise ValueError('Audio type "{}" not supported'.format(audio_type))
|
||||
self.audio_type = audio_type
|
||||
self.bitrate = bitrate
|
||||
self.sdb_file = open(sdb_filename, 'wb', buffering=buffering)
|
||||
self.sdb_file = open_remote(sdb_filename, 'wb', buffering=buffering)
|
||||
self.offsets = []
|
||||
self.num_samples = 0
|
||||
|
||||
|
@ -215,7 +245,7 @@ class SDB: # pylint: disable=too-many-instance-attributes
|
|||
"""
|
||||
self.sdb_filename = sdb_filename
|
||||
self.id_prefix = sdb_filename if id_prefix is None else id_prefix
|
||||
self.sdb_file = open(sdb_filename, 'rb', buffering=REVERSE_BUFFER_SIZE if reverse else buffering)
|
||||
self.sdb_file = open_remote(sdb_filename, 'rb', buffering=REVERSE_BUFFER_SIZE if reverse else buffering)
|
||||
self.offsets = []
|
||||
if self.sdb_file.read(len(MAGIC)) != MAGIC:
|
||||
raise RuntimeError('No Sample Database')
|
||||
|
@ -332,6 +362,8 @@ class CSVWriter: # pylint: disable=too-many-instance-attributes
|
|||
labeled : bool or None
|
||||
If True: Writes labeled samples (util.sample_collections.LabeledSample) only.
|
||||
If False: Ignores transcripts (if available) and writes (unlabeled) util.audio.Sample instances.
|
||||
|
||||
Currently only works with local files (not gs:// or hdfs://...)
|
||||
"""
|
||||
self.csv_filename = Path(csv_filename)
|
||||
self.csv_base_dir = self.csv_filename.parent.resolve().absolute()
|
||||
|
@ -345,7 +377,7 @@ class CSVWriter: # pylint: disable=too-many-instance-attributes
|
|||
self.labeled = labeled
|
||||
if labeled:
|
||||
fieldnames.append('transcript')
|
||||
self.csv_file = open(csv_filename, 'w', encoding='utf-8', newline='')
|
||||
self.csv_file = open_remote(csv_filename, 'w', encoding='utf-8', newline='')
|
||||
self.csv_writer = csv.DictWriter(self.csv_file, fieldnames=fieldnames)
|
||||
self.csv_writer.writeheader()
|
||||
self.counter = 0
|
||||
|
@ -380,7 +412,7 @@ class CSVWriter: # pylint: disable=too-many-instance-attributes
|
|||
|
||||
|
||||
class TarWriter: # pylint: disable=too-many-instance-attributes
|
||||
"""Sample collection writer for writing a CSV data-set and all its referenced WAV samples to a tar file"""
|
||||
"""Sample collection writer for writing a CSV data-set and all its referenced WAV samples to a tar file."""
|
||||
def __init__(self,
|
||||
tar_filename,
|
||||
gz=False,
|
||||
|
@ -398,6 +430,8 @@ class TarWriter: # pylint: disable=too-many-instance-attributes
|
|||
If False: Ignores transcripts (if available) and writes (unlabeled) util.audio.Sample instances.
|
||||
include : str[]
|
||||
List of files to include into tar root.
|
||||
|
||||
Currently only works with local files (not gs:// or hdfs://...)
|
||||
"""
|
||||
self.tar = tarfile.open(tar_filename, 'w:gz' if gz else 'w')
|
||||
samples_dir = tarfile.TarInfo('samples')
|
||||
|
@ -498,8 +532,7 @@ class CSV(SampleList):
|
|||
If the order of the samples should be reversed
|
||||
"""
|
||||
rows = []
|
||||
csv_dir = Path(csv_filename).parent
|
||||
with open(csv_filename, 'r', encoding='utf8') as csv_file:
|
||||
with open_remote(csv_filename, 'r', encoding='utf8') as csv_file:
|
||||
reader = csv.DictReader(csv_file)
|
||||
if 'transcript' in reader.fieldnames:
|
||||
if labeled is None:
|
||||
|
@ -508,9 +541,12 @@ class CSV(SampleList):
|
|||
raise RuntimeError('No transcript data (missing CSV column)')
|
||||
for row in reader:
|
||||
wav_filename = Path(row['wav_filename'])
|
||||
if not wav_filename.is_absolute():
|
||||
wav_filename = csv_dir / wav_filename
|
||||
wav_filename = str(wav_filename)
|
||||
if not wav_filename.is_absolute() and not is_remote_path(row['wav_filename']):
|
||||
wav_filename = Path(csv_filename).parent / wav_filename
|
||||
wav_filename = str(wav_filename)
|
||||
else:
|
||||
# Pathlib otherwise removes a / from filenames like hdfs://
|
||||
wav_filename = row['wav_filename']
|
||||
wav_filesize = int(row['wav_filesize']) if 'wav_filesize' in row else 0
|
||||
if labeled:
|
||||
rows.append((wav_filename, wav_filesize, row['transcript']))
|
||||
|
@ -554,6 +590,11 @@ def samples_from_sources(sample_sources, buffering=BUFFER_SIZE, labeled=None, re
|
|||
Loads and combines samples from a list of source files. Sources are combined in an interleaving way to
|
||||
keep default sample order from shortest to longest.
|
||||
|
||||
Note that when using distributed training, it is much faster to call this function with single pre-
|
||||
sorted sample source, because this allows for parallelization of the file I/O. (If this function is
|
||||
called with multiple sources, the samples have to be unpacked on a single parent process to allow
|
||||
for reading their durations.)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sample_sources : list of str
|
||||
|
@ -570,13 +611,20 @@ def samples_from_sources(sample_sources, buffering=BUFFER_SIZE, labeled=None, re
|
|||
|
||||
Returns
|
||||
-------
|
||||
iterable of util.sample_collections.LabeledSample (labeled=True) or util.audio.Sample (labeled=False) supporting len
|
||||
iterable of util.sample_collections.PackedSample if a single collection is provided, wrapping
|
||||
LabeledSample (labeled=True) or util.audio.Sample (labeled=False) supporting len
|
||||
or LabeledSample / util.audio.Sample directly, if multiple collections are provided
|
||||
"""
|
||||
sample_sources = list(sample_sources)
|
||||
if len(sample_sources) == 0:
|
||||
raise ValueError('No files')
|
||||
if len(sample_sources) == 1:
|
||||
return samples_from_source(sample_sources[0], buffering=buffering, labeled=labeled, reverse=reverse)
|
||||
cols = [samples_from_source(source, buffering=buffering, labeled=labeled, reverse=reverse)
|
||||
for source in sample_sources]
|
||||
|
||||
# If we wish to interleave based on duration, we have to unpack the audio. Note that this unpacking should
|
||||
# be done lazily onn the fly so that it respects the LimitingPool logic used in the feeding code.
|
||||
cols = [LenMap(
|
||||
unpack_maybe, samples_from_source(source, buffering=buffering, labeled=labeled, reverse=reverse))
|
||||
for source in sample_sources]
|
||||
|
||||
return Interleaved(*cols, key=lambda s: s.duration, reverse=reverse)
|
||||
|
|
Загрузка…
Ссылка в новой задаче